Upload 96 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +11 -0
- LICENSE +21 -0
- README.md +80 -12
- assets/intro.jpg +3 -0
- baseline_generate/ace_step/convert.py +160 -0
- baseline_generate/ace_step/infer.py +122 -0
- baseline_generate/diffrhythm2/batch_inference.sh +57 -0
- baseline_generate/diffrhythm2/batch_inference_en.sh +57 -0
- baseline_generate/diffrhythm2/inference.py +294 -0
- baseline_generate/diffrhythm2/inference.sh +10 -0
- baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt.cpython-311.pyc +0 -0
- baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt_justtag.cpython-311.pyc +0 -0
- baseline_generate/diffrhythm2/scripts/proce_song.py +92 -0
- baseline_generate/diffrhythm2/scripts/proce_song_enprompt.py +145 -0
- baseline_generate/diffrhythm2/scripts/proce_song_enprompt_justtag.py +143 -0
- baseline_generate/levo/__pycache__/generate.cpython-311.pyc +0 -0
- baseline_generate/levo/convert.py +110 -0
- baseline_generate/levo/generate.py +591 -0
- baseline_generate/mureka_o2/__pycache__/generate.cpython-311.pyc +0 -0
- baseline_generate/mureka_o2/generate.py +390 -0
- baseline_generate/suno/__pycache__/suno_4_5.cpython-311.pyc +0 -0
- baseline_generate/suno/__pycache__/suno_5.cpython-311.pyc +0 -0
- baseline_generate/suno/config.py +70 -0
- baseline_generate/suno/suno_4_5.py +766 -0
- baseline_generate/suno/suno_5.py +768 -0
- baseline_generate/yue/__pycache__/infer_batch.cpython-311.pyc +0 -0
- baseline_generate/yue/batch.sh +55 -0
- baseline_generate/yue/codecmanipulator.py +204 -0
- baseline_generate/yue/infer_batch.py +904 -0
- baseline_generate/yue/mmtokenizer.py +367 -0
- data_pipeline/lyrics_gene/__pycache__/filter_all_cn.cpython-311.pyc +0 -0
- data_pipeline/lyrics_gene/__pycache__/filter_all_en.cpython-311.pyc +0 -0
- data_pipeline/lyrics_gene/__pycache__/gen_lyrics_cn.cpython-311.pyc +0 -0
- data_pipeline/lyrics_gene/filter_all_cn.py +272 -0
- data_pipeline/lyrics_gene/filter_all_en.py +256 -0
- data_pipeline/lyrics_gene/gen_lyrics_cn.py +568 -0
- data_pipeline/lyrics_gene/gen_lyrics_en.py +577 -0
- data_pipeline/meta_process/convert_convs.py +98 -0
- data_pipeline/meta_process/convert_lyrics.py +180 -0
- data_pipeline/meta_process/convert_messages.py +593 -0
- data_pipeline/meta_process/convert_segments.py +93 -0
- data_pipeline/meta_process/evaluate_polyphones.py +62 -0
- data_pipeline/meta_process/filter.py +46 -0
- data_pipeline/meta_process/main.py +77 -0
- data_pipeline/meta_process/meta_endpoints.py +118 -0
- data_pipeline/meta_process/meta_lang.py +125 -0
- data_pipeline/meta_process/meta_phonemes.py +283 -0
- data_pipeline/meta_process/meta_tags.py +124 -0
- data_pipeline/meta_process/meta_vocal.py +141 -0
- data_pipeline/meta_process/my_tool.py +551 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/intro.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/muse_outputs/main_0.6b_0.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/muse_outputs/main_0.6b_1.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/muse_outputs/main_1.7b_2.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/muse_outputs/main_8b_3.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/train_inputs/suno_cn_0.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/train_inputs/suno_cn_1.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/train_inputs/suno_en_2.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/train_inputs/suno_en_3.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
train/train_demo.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
train/val.jsonl filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 yuhui1038
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,12 +1,80 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Muse: Towards Reproducible Long-Form Song Generation with Fine-Grained Style Control
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
📄 <a href="https://arxiv.org/abs/2601.03973">Paper</a> • 📊 <a href="https://huggingface.co/datasets/bolshyC/Muse">Dataset</a> • 🤖 <a href="https://huggingface.co/bolshyC/models">Model</a> • 📚 <a href="#citation">Citation</a>
|
| 5 |
+
</p>
|
| 6 |
+
|
| 7 |
+
This is the official repository for "Muse: Towards Reproducible Long-Form Song Generation with Fine-Grained Style Control". It provides the Muse model, training and inference scripts, pretrained checkpoints, and evaluation pipelines.
|
| 8 |
+
|
| 9 |
+
## News and Updates
|
| 10 |
+
|
| 11 |
+
* **2026.01.11 🔥**: We are excited to announce that all datasets and models are now fully open-sourced! 🎶 The complete training dataset (116k songs), pretrained model weights, training and evaluation code, and data pipeline are publicly available.
|
| 12 |
+
|
| 13 |
+
## Installation
|
| 14 |
+
|
| 15 |
+
**Requirements**: Python 3.10 is required.
|
| 16 |
+
|
| 17 |
+
To set up the environment for Muse:
|
| 18 |
+
|
| 19 |
+
- **For training**: Install the training framework:
|
| 20 |
+
```bash
|
| 21 |
+
pip install ms-swift -U
|
| 22 |
+
```
|
| 23 |
+
- **For inference**: Install vLLM:
|
| 24 |
+
```bash
|
| 25 |
+
pip install vllm
|
| 26 |
+
```
|
| 27 |
+
- **For audio encoding/decoding**: Some dependencies (e.g., `av`) require system-level packages. On Ubuntu/Debian, install FFmpeg 4.4+ first:
|
| 28 |
+
```bash
|
| 29 |
+
sudo apt-get update
|
| 30 |
+
sudo apt-get install -y software-properties-common
|
| 31 |
+
sudo add-apt-repository ppa:savoury1/ffmpeg4 -y
|
| 32 |
+
sudo apt-get update
|
| 33 |
+
sudo apt-get install -y pkg-config ffmpeg libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
| 34 |
+
```
|
| 35 |
+
We recommend creating a new conda environment with Python 3.10. **Note**: Since `omegaconf==2.0.6` is required and has compatibility issues with pip 24.1+, you need to downgrade pip first:
|
| 36 |
+
```bash
|
| 37 |
+
pip install "pip<24.1"
|
| 38 |
+
```
|
| 39 |
+
Then install dependencies:
|
| 40 |
+
```bash
|
| 41 |
+
pip install --default-timeout=1000 -r requirements_mucodec.txt
|
| 42 |
+
```
|
| 43 |
+
For more details, please refer to the [MuCodec](https://github.com/tencent-ailab/MuCodec) official repository.
|
| 44 |
+
|
| 45 |
+
- **For data pipeline and evaluation**: If you need to run data processing scripts (lyrics generation, metadata processing) or evaluation scripts, install additional dependencies:
|
| 46 |
+
```bash
|
| 47 |
+
pip install -r requirements_data_eval.txt
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Repository Structure
|
| 51 |
+
|
| 52 |
+
This repository contains the following main directories:
|
| 53 |
+
|
| 54 |
+
- **`train/`**: Training scripts and utilities for fine-tuning the Muse model. See [`train/README.md`](train/README.md) for details.
|
| 55 |
+
- **`infer/`**: Inference scripts for generating music with the Muse model. See [`infer/README.md`](infer/README.md) for details.
|
| 56 |
+
- **`eval_pipeline/`**: Evaluation scripts for assessing model performance (Mulan-T, PER, AudioBox, SongEval, etc.).
|
| 57 |
+
- **`data_pipeline/`**: Scripts for building and processing training data, including lyrics generation, metadata processing, and music generation utilities.
|
| 58 |
+
|
| 59 |
+
## Model Architecture
|
| 60 |
+
|
| 61 |
+
<p align="center">
|
| 62 |
+
<img src="assets/intro.jpg" width="800"/>
|
| 63 |
+
</p>
|
| 64 |
+
|
| 65 |
+
## Acknowledgments
|
| 66 |
+
|
| 67 |
+
We thank [Qwen3](https://github.com/QwenLM/Qwen3) for providing the base language model, [ms-swift](https://github.com/modelscope/ms-swift) for the training framework, and [MuCodec](https://github.com/tencent-ailab/MuCodec) for discrete audio tokenization.
|
| 68 |
+
|
| 69 |
+
## Citation
|
| 70 |
+
|
| 71 |
+
If you find our work useful, please cite our paper:
|
| 72 |
+
|
| 73 |
+
```bibtex
|
| 74 |
+
@article{jiang2026muse,
|
| 75 |
+
title={Muse: Towards Reproducible Long-Form Song Generation with Fine-Grained Style Control},
|
| 76 |
+
author={Jiang, Changhao and Chen, Jiahao and Xiang, Zhenghao and Yang, Zhixiong and Wang, Hanchen and Zhuang, Jiabao and Che, Xinmeng and Sun, Jiajun and Li, Hui and Cao, Yifei and others},
|
| 77 |
+
journal={arXiv preprint arXiv:2601.03973},
|
| 78 |
+
year={2026}
|
| 79 |
+
}
|
| 80 |
+
```
|
assets/intro.jpg
ADDED
|
Git LFS Details
|
baseline_generate/ace_step/convert.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert data to ACE-STEP acceptable format
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
random.seed(42)
|
| 11 |
+
|
| 12 |
+
def load_jsonl(path:str) -> list[dict]:
    """Read a JSONL file and return its records as a list of dicts, with a progress bar."""
    records: list[dict] = []
    with open(path, 'r') as fh:
        for raw_line in tqdm(fh, desc=f"Loading {path}"):
            records.append(json.loads(raw_line))
    return records
|
| 18 |
+
|
| 19 |
+
def save_jsonl(data:list, path:str):
    """Write each element of *data* to *path* as one JSON object per line, with a progress bar."""
    with open(path, 'w', encoding='utf-8') as fh:
        for record in tqdm(data, desc=f"Saving {path}"):
            json.dump(record, fh, ensure_ascii=False)
            fh.write("\n")
|
| 24 |
+
|
| 25 |
+
START_STR = "Please generate a song in the following style:"
|
| 26 |
+
END_STR = "\nNext, I will tell you the requirements and lyrics"
|
| 27 |
+
|
| 28 |
+
def process_tag(content:str) -> str:
    """Extract the leading segment label from *content* and normalize it.

    The label sits between the opening '[' and the '[desc:' marker, e.g.
    "[Verse 1][desc: ...]" -> "[verse]".
    """
    desc_pos = content.find("[desc:")
    raw = content[1:desc_pos - 1]
    # Normalize: lowercase, strip digit runs, drop parenthesized notes.
    raw = re.sub(r'\d+', '', raw.lower())
    raw = re.sub(r'\([^)]*\)', '', raw).strip()
    # Pre-chorus segments are folded into chorus.
    return "[chorus]" if raw == "pre-chorus" else f"[{raw}]"
|
| 40 |
+
|
| 41 |
+
def process_lyrics(content:str) -> str:
    """Pull the lyric text between '[lyrics:' and '[phoneme:' markers and
    normalize its punctuation to single newlines.

    Returns "" when the lyrics marker is absent.
    """
    marker = "[lyrics:\n"
    begin = content.find(marker)
    if begin == -1:
        return ""
    stop = content.find("][phoneme:")
    text = content[begin + len(marker):stop]

    # Map CJK/ASCII punctuation, brackets and dashes to line breaks.
    text = re.sub(r'[,。",:;&—‘\'.\]\[()?\n-]', '\n', text)
    # Collapse any run of blank lines into a single break.
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    # Trim one trailing newline, if present.
    if text.endswith('\n'):
        text = text[:-1]
    return text
|
| 58 |
+
|
| 59 |
+
def has_chinese(text) -> bool:
    """Return True if *text* contains any character in the basic CJK Unicode block."""
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)

def process_duration(lyrics:str):
    """Estimate a song duration (seconds) from lyric length.

    Chinese lyrics are counted per character, others per whitespace-separated
    word; the duration is drawn uniformly from 40%-70% of that count.
    """
    if has_chinese(lyrics):
        count = len(lyrics.replace("\n", ""))
    else:
        count = len(lyrics.replace("\n", " ").split())
    return random.randint(int(count * 0.4), int(count * 0.7))
|
| 74 |
+
|
| 75 |
+
def process_one(messages:list[dict]):
    """Convert one conversation's messages into an ACE-Step input dict.

    Returns a dict carrying the extracted style prompt, the assembled
    "[tag]\\n<lyric>" song text, an estimated duration, and fixed ACE-Step
    sampling parameters.
    """
    # Overall style: taken from the first message, between the fixed
    # START_STR / END_STR markers defined at module level.
    style:str = messages[0]['content']
    start = style.find(START_STR)
    end = style.find(END_STR)
    descriptions = style[start+len(START_STR):end]

    # Line-by-line lyrics: walk the remaining turns, skipping assistant
    # replies, and stitch one "[tag]\n<lyric>" segment per user message.
    all_lyrics = "[intro]\n\n"
    pure_lyrics = ""
    for message in messages[1:]:
        if message['role'] == "assistant":
            continue
        content = message['content']
        # Segment label (e.g. "[verse]")
        tag = process_tag(content)
        # Segment lyrics, punctuation normalized to newlines
        lyric = process_lyrics(content)
        all_lyrics += f"{tag}\n{lyric}\n\n"
        pure_lyrics += lyric
    all_lyrics = all_lyrics[:-2]  # drop the trailing blank separator

    # Duration estimated from the concatenated lyric text.
    duration = process_duration(pure_lyrics)

    # Fixed ACE-Step generation parameters; only prompt / lyrics / duration
    # vary per song. Seed is pinned for reproducibility.
    obj = {
        "prompt": descriptions,
        "lyrics": all_lyrics,
        "audio_duration": duration,
        "infer_step": 60,
        "guidance_scale": 15,
        "scheduler_type": "euler",
        "cfg_type": "apg",
        "omega_scale": 10,
        "guidance_interval": 0.5,
        "guidance_interval_decay": 0,
        "min_guidance_scale": 3,
        "use_erg_tag": True,
        "use_erg_lyric": True,
        "use_erg_diffusion": True,
        "oss_steps": [],
        "actual_seeds": [
            3299954530
        ]
    }
    return obj
|
| 122 |
+
|
| 123 |
+
def main():
    """Convert every conversation in the messages JSONL into one ACE-Step
    input file per song, written to ./data/inputs/test_<index>.jsonl.
    """
    # renamed from `path`: the original reused one variable for both the
    # source file and each per-song output file
    src_path = "xxx/ACE-Step/data/inputs/messages.jsonl"
    dataset = load_jsonl(src_path)

    # `index` instead of `id` — the original shadowed the builtin
    for index, ele in tqdm(enumerate(dataset), desc="Processing"):
        messages = ele['messages']
        data = process_one(messages)
        # NOTE: each output file is a single pretty-printed JSON object even
        # though it carries a .jsonl extension (consumed one file at a time
        # by infer.py).
        out_path = f"./data/inputs/test_{index}.jsonl"
        with open(out_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
|
| 133 |
+
|
| 134 |
+
# NOTE(review): this redefines load_jsonl from the top of the file; being
# later, it shadows the tqdm-wrapped version for the __main__ path.
def load_jsonl(path):
    """Read a JSONL file into a list of parsed objects."""
    with open(path, 'r') as file:
        return [json.loads(line) for line in file]
|
| 140 |
+
|
| 141 |
+
# NOTE(review): this redefines save_jsonl from the top of the file; being
# later, it shadows the tqdm-wrapped version for the __main__ path.
def save_jsonl(dataset, path):
    """Write *dataset* to *path*, one JSON object per line (UTF-8, not ASCII-escaped)."""
    with open(path, 'w', encoding='utf-8') as file:
        for item in dataset:
            json.dump(item, file, ensure_ascii=False)
            file.write("\n")
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
    # main()
    # Post-processing mode: re-sort generated results by the numeric index
    # embedded in each audio_path ("./data/outputs/test_<id>.wav") so the
    # output JSONL follows the original song order.
    dataset = load_jsonl("./data/outputs/lyrics_params.jsonl")
    for ele in dataset:
        path = ele['audio_path']
        # numeric index between the fixed prefix and the ".wav" suffix;
        # assumes every audio_path matches that exact pattern — TODO confirm
        ele['extra'] = int(path[len("./data/outputs/test_"):-len(".wav")])
    sorted_data = sorted(dataset, key=lambda x: x['extra'])

    save_path = "./data/outputs/lyrics_params_.jsonl"
    with open(save_path, 'w', encoding='utf-8') as file:
        for ele in sorted_data:
            del ele['extra']  # drop the temporary sort key before writing
            json.dump(ele, file, ensure_ascii=False)
            file.write("\n")
|
baseline_generate/ace_step/infer.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import click
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
from acestep.pipeline_ace_step import ACEStepPipeline
|
| 5 |
+
from acestep.data_sampler import DataSampler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def sample_data(json_data):
    """Flatten an ACE-Step input dict into the positional tuple consumed by
    the generation loop in main().

    List-valued fields (actual_seeds, oss_steps) are serialized to
    comma-separated strings; the two optional guidance scales default to 0.0
    when absent.
    """
    return (
        json_data["audio_duration"],
        json_data["prompt"],
        json_data["lyrics"],
        json_data["infer_step"],
        json_data["guidance_scale"],
        json_data["scheduler_type"],
        json_data["cfg_type"],
        json_data["omega_scale"],
        ", ".join(map(str, json_data["actual_seeds"])),
        json_data["guidance_interval"],
        json_data["guidance_interval_decay"],
        json_data["min_guidance_scale"],
        json_data["use_erg_tag"],
        json_data["use_erg_lyric"],
        json_data["use_erg_diffusion"],
        ", ".join(map(str, json_data["oss_steps"])),
        # dict.get replaces the verbose "x if k in d else 0.0" pattern
        json_data.get("guidance_scale_text", 0.0),
        json_data.get("guidance_scale_lyric", 0.0),
    )
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@click.command()
@click.option(
    "--checkpoint_path", type=str, default="", help="Path to the checkpoint directory"
)
@click.option("--bf16", type=bool, default=True, help="Whether to use bfloat16")
@click.option(
    "--torch_compile", type=bool, default=False, help="Whether to use torch compile"
)
@click.option(
    "--cpu_offload", type=bool, default=False, help="Whether to use CPU offloading (only load current stage's model to GPU)"
)
@click.option(
    "--overlapped_decode", type=bool, default=False, help="Whether to use overlapped decoding (run dcae and vocoder using sliding windows)"
)
@click.option("--device_id", type=int, default=0, help="Device ID to use")
@click.option("--output_path", type=str, default=None, help="Path to save the output")
def main(checkpoint_path, bf16, torch_compile, cpu_offload, overlapped_decode, device_id, output_path):
    """Run ACE-Step generation for every ./data/inputs/test*.jsonl input,
    writing one .wav per input into ./data/outputs.

    Already-generated outputs are skipped, so the loop is resumable.
    NOTE(review): `output_path` is reassigned inside the loop, so the CLI
    --output_path option is effectively ignored; `device_id` is also unused
    while its os.environ line stays commented out.
    """
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode
    )
    print(model_demo)

    # NOTE(review): data_sampler is constructed but never used below.
    data_sampler = DataSampler()

    inputs_dir = "./data/inputs"
    for id, name in enumerate(os.listdir(inputs_dir)):
        # only files produced by convert.py (test_*.jsonl) are processed
        if not name.startswith("test"):
            continue
        path = os.path.join(inputs_dir, name)
        with open(path, 'r') as file:
            json_data = json.load(file)
        # flatten the input dict into the positional parameter tuple
        json_data = sample_data(json_data)

        pure_name = os.path.splitext(name)[0]
        # per-song output path; skip songs that were already generated
        output_path = f"./data/outputs/{pure_name}.wav"
        if os.path.exists(output_path):
            continue
        (
            audio_duration,
            prompt,
            lyrics,
            infer_step,
            guidance_scale,
            scheduler_type,
            cfg_type,
            omega_scale,
            manual_seeds,
            guidance_interval,
            guidance_interval_decay,
            min_guidance_scale,
            use_erg_tag,
            use_erg_lyric,
            use_erg_diffusion,
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
        ) = json_data

        # run the full text+lyrics -> audio pipeline and save the wav
        model_demo(
            audio_duration=audio_duration,
            prompt=prompt,
            lyrics=lyrics,
            infer_step=infer_step,
            guidance_scale=guidance_scale,
            scheduler_type=scheduler_type,
            cfg_type=cfg_type,
            omega_scale=omega_scale,
            manual_seeds=manual_seeds,
            guidance_interval=guidance_interval,
            guidance_interval_decay=guidance_interval_decay,
            min_guidance_scale=min_guidance_scale,
            use_erg_tag=use_erg_tag,
            use_erg_lyric=use_erg_lyric,
            use_erg_diffusion=use_erg_diffusion,
            oss_steps=oss_steps,
            guidance_scale_text=guidance_scale_text,
            guidance_scale_lyric=guidance_scale_lyric,
            save_path=output_path,
        )
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
    # click parses the CLI options declared on main()
    main()
|
baseline_generate/diffrhythm2/batch_inference.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Batch DiffRhythm2 inference over every Chinese example song
# (example/zh_songs/song_*.jsonl), one inference.py run per file.
set -euo pipefail

# Navigate to script directory to ensure relative paths are consistent
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"

SONG_DIR="$SCRIPT_DIR/example/zh_songs"

if [ ! -d "$SONG_DIR" ]; then
    echo "Song directory not found: $SONG_DIR"
    exit 1
fi

# Collect all song_*.jsonl files
shopt -s nullglob
SONG_FILES=("$SONG_DIR"/song_*.jsonl)
shopt -u nullglob

if [ ${#SONG_FILES[@]} -eq 0 ]; then
    echo "No song_*.jsonl files in song directory"
    exit 0
fi

export PYTHONPATH="${PYTHONPATH:-}:${SCRIPT_DIR}"

# Fail fast (via set -e) if the espeak-ng phonemizer backend is missing
espeak-ng --version

# Reproducibility settings:
# - Fixed random seed SEED
# - DO_SAMPLE=0 tries to follow deterministic path (including fixed style prompt cropping start)
SEED="${SEED:-42}"
DO_SAMPLE="${DO_SAMPLE:-0}"

# Further reduce cuBLAS non-determinism (enable when needed; comment out if causes errors)
export CUBLAS_WORKSPACE_CONFIG="${CUBLAS_WORKSPACE_CONFIG:-:4096:8}"

for SONG_FILE in "${SONG_FILES[@]}"; do
    SONG_NAME="$(basename "$SONG_FILE")"
    INPUT_PATH="./example/zh_songs/${SONG_NAME}"
    echo "=============================="
    echo "Starting generation: ${SONG_NAME}"
    CMD=(python inference.py
        --repo-id ASLP-lab/DiffRhythm2
        --output-dir ./results/zh
        --input-jsonl "$INPUT_PATH"
        --cfg-strength 3.0
        --max-secs 285.0
        --seed "$SEED"
    )
    if [ "$DO_SAMPLE" -eq 1 ]; then
        CMD+=(--do-sample)
    fi
    "${CMD[@]}"
done

echo "All songs generation complete, processed ${#SONG_FILES[@]} songs."
|
baseline_generate/diffrhythm2/batch_inference_en.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Batch DiffRhythm2 inference over every English example song
# (example/en_songs/song_*.jsonl); mirrors batch_inference.sh for zh.
set -euo pipefail

# Navigate to script directory to ensure relative paths are consistent
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"

SONG_DIR="$SCRIPT_DIR/example/en_songs"

if [ ! -d "$SONG_DIR" ]; then
    echo "Song directory not found: $SONG_DIR"
    exit 1
fi

# Collect all song_*.jsonl files
shopt -s nullglob
SONG_FILES=("$SONG_DIR"/song_*.jsonl)
shopt -u nullglob

if [ ${#SONG_FILES[@]} -eq 0 ]; then
    echo "No song_*.jsonl files in song directory"
    exit 0
fi

export PYTHONPATH="${PYTHONPATH:-}:${SCRIPT_DIR}"

# Fail fast (via set -e) if the espeak-ng phonemizer backend is missing
espeak-ng --version

# Reproducibility settings:
# - Fixed random seed SEED
# - DO_SAMPLE=0 tries to follow deterministic path (including fixed style prompt cropping start)
SEED="${SEED:-42}"
DO_SAMPLE="${DO_SAMPLE:-0}"

# Further reduce cuBLAS non-determinism (enable when needed; comment out if causes errors)
export CUBLAS_WORKSPACE_CONFIG="${CUBLAS_WORKSPACE_CONFIG:-:4096:8}"

for SONG_FILE in "${SONG_FILES[@]}"; do
    SONG_NAME="$(basename "$SONG_FILE")"
    INPUT_PATH="./example/en_songs/${SONG_NAME}"
    echo "=============================="
    echo "Starting generation: ${SONG_NAME}"
    CMD=(python inference.py
        --repo-id ASLP-lab/DiffRhythm2
        --output-dir ./results/en
        --input-jsonl "$INPUT_PATH"
        --cfg-strength 3.0
        --max-secs 285.0
        --seed "$SEED"
    )
    if [ "$DO_SAMPLE" -eq 1 ]; then
        CMD+=(--do-sample)
    fi
    "${CMD[@]}"
done

echo "All songs generation complete, processed ${#SONG_FILES[@]} songs."
|
baseline_generate/diffrhythm2/inference.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torchaudio
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
import random
|
| 22 |
+
import pedalboard
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
from muq import MuQMuLan
|
| 26 |
+
from diffrhythm2.cfm import CFM
|
| 27 |
+
from diffrhythm2.backbones.dit import DiT
|
| 28 |
+
from bigvgan.model import Generator
|
| 29 |
+
from huggingface_hub import hf_hub_download
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
STRUCT_INFO = {
|
| 33 |
+
"[start]": 500,
|
| 34 |
+
"[end]": 501,
|
| 35 |
+
"[intro]": 502,
|
| 36 |
+
"[verse]": 503,
|
| 37 |
+
"[chorus]": 504,
|
| 38 |
+
"[outro]": 505,
|
| 39 |
+
"[inst]": 506,
|
| 40 |
+
"[solo]": 507,
|
| 41 |
+
"[bridge]": 508,
|
| 42 |
+
"[hook]": 509,
|
| 43 |
+
"[break]": 510,
|
| 44 |
+
"[stop]": 511,
|
| 45 |
+
"[space]": 512
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
lrc_tokenizer = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def set_seed(seed: int, deterministic: bool = True):
    """Seed python, numpy and torch RNGs; optionally force deterministic kernels.

    With deterministic=True this is best-effort only: some GPU kernels stay
    nondeterministic, hence warn_only on use_deterministic_algorithms.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except Exception:
        # older torch versions lack this API (or reject warn_only)
        pass
|
| 66 |
+
|
| 67 |
+
class CNENTokenizer():
    """Grapheme-to-phoneme tokenizer for mixed Chinese/English lyrics.

    Wraps the project's g2p frontend and shifts token ids by +1 so that
    id 0 stays free.
    """

    def __init__(self):
        here = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(here, "g2p/g2p/vocab.json"), 'r') as fp:
            self.phone2id: dict = json.load(fp)['vocab']
        self.id2phone = {idx: phone for phone, idx in self.phone2id.items()}
        # Imported lazily so that merely importing this module does not
        # require the g2p package.
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        """Return 1-based phoneme token ids for *text*."""
        _, raw_ids = self.tokenizer(text)
        return [tok + 1 for tok in raw_ids]

    def decode(self, token):
        """Map 1-based token ids back to a '|'-joined phoneme string."""
        return "|".join(self.id2phone[tok - 1] for tok in token)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def prepare_model(repo_id, device):
    """Download and assemble all DiffRhythm2 inference components.

    Args:
        repo_id: Hugging Face repo id hosting weights and configs.
        device: torch device the models are moved to.

    Returns:
        Tuple ``(diffrhythm2, mulan, lrc_tokenizer, decoder)``.
    """
    # CFM backbone checkpoint + config, cached under ./ckpt.
    diffrhythm2_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir="./ckpt",
        local_files_only=False,
    )
    diffrhythm2_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    with open(diffrhythm2_config_path) as f:
        model_config = json.load(f)

    # Flex attention disabled — presumably for broader hardware/torch
    # compatibility; TODO confirm.
    model_config['use_flex_attn'] = False
    diffrhythm2 = CFM(
        transformer=DiT(
            **model_config
        ),
        num_channels=model_config['mel_dim'],
        block_size=model_config['block_size'],
    )

    total_params = sum(p.numel() for p in diffrhythm2.parameters())

    diffrhythm2 = diffrhythm2.to(device)
    if diffrhythm2_ckpt_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        ckpt = load_file(diffrhythm2_ckpt_path)
    else:
        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
    diffrhythm2.load_state_dict(ckpt)
    print(f"Total params: {total_params:,}")

    # load Mulan (joint text/audio embedding model used for style prompts)
    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device)

    # load frontend
    # NOTE(review): this binds a local, not the module-level lrc_tokenizer
    # global; callers must use the returned value (the __main__ block does).
    lrc_tokenizer = CNENTokenizer()

    # load decoder (vocoder turning latents back into waveform)
    decoder_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.bin",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder = Generator(decoder_config_path, decoder_ckpt_path)
    decoder = decoder.to(device)
    return diffrhythm2, mulan, lrc_tokenizer, decoder
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def parse_lyrics(lyrics: str):
    """Tokenize raw lyric text into per-line token-id lists.

    A line that exactly matches a STRUCT_INFO key (e.g. "[chorus]")
    becomes its structure id; any other line is phonemized via the
    module-level ``lrc_tokenizer``.  Each entry ends with the [stop] id.
    """
    stop_id = STRUCT_INFO['[stop]']
    parsed = []
    for raw_line in lyrics.split("\n"):
        tag_id = STRUCT_INFO.get(raw_line)
        if tag_id is None:
            parsed.append(lrc_tokenizer.encode(raw_line.strip()) + [stop_id])
        else:
            parsed.append([tag_id, stop_id])
    return parsed
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def make_fake_stereo(audio, sampling_rate):
    """Widen a mono (1, T) signal into pseudo-stereo.

    The right channel is an attenuated (x0.8) copy delayed by 10 ms;
    samples shifted in by the delay are zeroed rather than wrapped.

    Args:
        audio: mono waveform shaped (1, num_samples).
        sampling_rate: sample rate in Hz (sizes the 10 ms delay).

    Returns:
        np.ndarray shaped (2, num_samples): rows are [left, right].
    """
    delay = int(0.01 * sampling_rate)
    right = np.roll(audio.copy() * 0.8, delay)
    right[:, :delay] = 0  # np.roll wraps around; silence the wrapped head
    return np.concatenate([audio, right], axis=0)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def inference(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    output_dir,
    song_name,
    cfg_strength,
    sample_steps=32,
    process_bar=True,
    fake_stereo=True,
):
    """Generate one song and write it to <output_dir>/<song_name>.mp3.

    Args:
        model: CFM model exposing sample_block_cache().
        decoder: vocoder exposing decode_audio() and .h.sampling_rate.
        text: 1-D tensor of lyric/structure token ids (unbatched).
        style_prompt: style embedding tensor (unbatched).
        duration: target length in seconds; multiplied by 5, so the
            latent rate is presumably 5 frames/s — TODO confirm.
        output_dir: directory the rendered mp3 is written to.
        song_name: output file stem.
        cfg_strength: classifier-free guidance scale.
        sample_steps: number of sampling steps.
        process_bar: show sampling progress.
        fake_stereo: widen the mono output into pseudo-stereo.
    """
    with torch.inference_mode():
        latent = model.sample_block_cache(
            text=text.unsqueeze(0),                  # add batch dim
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),  # add batch dim
            steps=sample_steps,
            cfg_strength=cfg_strength,
            process_bar=process_bar,
        )
        latent = latent.transpose(1, 2)
        # Chunked decoding with overlap keeps vocoder memory bounded.
        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20)

    basename = f"{song_name}.mp3"
    output_path = os.path.join(output_dir, basename)

    num_channels = 1
    # -> float32 numpy array shaped (1, num_samples)
    audio = audio.float().cpu().numpy().squeeze()[None, :]
    if fake_stereo:
        audio = make_fake_stereo(audio, decoder.h.sampling_rate)
        num_channels = 2

    # pedalboard picks the encoder (mp3) from the output file extension.
    with pedalboard.io.AudioFile(output_path, "w", decoder.h.sampling_rate, num_channels) as f:
        f.write(audio)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument('--repo-id', type=str, default=None)
    parser.add_argument('--output-dir', type=str, default=None)
    parser.add_argument('--input-jsonl', type=str, default=None)
    parser.add_argument('--cfg-strength', type=float, default=2.0)
    parser.add_argument('--max-secs', type=float, default=210.0)
    parser.add_argument('--steps', type=int, default=16)
    # FIX: argparse `type=bool` treats ANY non-empty string (even "False")
    # as True; parse common falsy spellings explicitly instead.
    parser.add_argument('--fake-stereo',
                        type=lambda v: str(v).strip().lower() not in ('0', 'false', 'no', ''),
                        default=True)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--do-sample', action='store_true', default=False)

    args = parser.parse_args()

    output_dir = args.output_dir
    input_jsonl = args.input_jsonl
    cfg_strength = args.cfg_strength
    max_secs = args.max_secs
    # NOTE(review): GPU index 7 is hard-coded for the authors' machine.
    device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')

    # reproducibility: deterministic kernels unless sampling is requested
    set_seed(args.seed, deterministic=(not args.do_sample))

    # load diffrhythm2 stack (CFM, MuLan, g2p frontend, vocoder)
    diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model(args.repo_id, device)

    os.makedirs(output_dir, exist_ok=True)

    with open(input_jsonl, 'r') as f:
        input_info = [json.loads(i.strip()) for i in f.readlines()]

    for i in tqdm(range(len(input_info))):
        info = input_info[i]
        song_name = info.get('song_name', f"{i:04d}")
        lyrics = info.get('lyrics', None)
        style_prompt = info.get('style_prompt', None)
        if lyrics is None or style_prompt is None:
            print(f"lyrics or style_prompt is None, skip {song_name}")
            continue

        # preprocess lyrics: the jsonl stores a path to the lyric file
        with open(lyrics, 'r') as f:
            lyrics = f.read()
        lyrics_token = parse_lyrics(lyrics)
        lyrics_token = torch.tensor(sum(lyrics_token, []), dtype=torch.long, device=device)

        # preprocess style prompt: an existing audio file path is embedded
        # with MuLan's audio tower, anything else is treated as text
        if os.path.isfile(style_prompt):
            prompt_wav, sr = torchaudio.load(style_prompt)
            prompt_wav = torchaudio.functional.resample(prompt_wav.to(device), sr, 24000)
            if prompt_wav.shape[1] > 24000 * 10:
                # MuLan sees at most 10 s: random crop when sampling, else head
                if args.do_sample:
                    start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
                else:
                    start = 0
                prompt_wav = prompt_wav[:, start:start + 24000 * 10]
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True)  # downmix to mono
            with torch.no_grad():
                style_prompt_embed = mulan(wavs=prompt_wav)
        else:
            with torch.no_grad():
                style_prompt_embed = mulan(texts=[style_prompt])
        style_prompt_embed = style_prompt_embed.to(device).squeeze(0)

        if device.type != 'cpu':
            # fp16 inference on GPU
            diffrhythm2 = diffrhythm2.half()
            decoder = decoder.half()
            style_prompt_embed = style_prompt_embed.half()

        inference(
            model=diffrhythm2,
            decoder=decoder,
            text=lyrics_token,
            style_prompt=style_prompt_embed,
            duration=max_secs,
            output_dir=output_dir,
            song_name=song_name,
            sample_steps=args.steps,
            cfg_strength=cfg_strength,
            fake_stereo=args.fake_stereo,
        )
|
| 293 |
+
|
| 294 |
+
|
baseline_generate/diffrhythm2/inference.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Make local packages (diffrhythm2, g2p, bigvgan) importable from $PWD.
export PYTHONPATH=$PYTHONPATH:$PWD
# Print espeak-ng version (fails fast if the g2p backend is missing).
espeak-ng --version

# FIX: the original command ended in a dangling "\" line continuation.
python inference.py \
    --repo-id ASLP-lab/DiffRhythm2 \
    --output-dir ./results/test \
    --input-jsonl ./example/song_1.jsonl \
    --cfg-strength 3.0 \
    --max-secs 285.0
|
baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt.cpython-311.pyc
ADDED
|
Binary file (6.37 kB). View file
|
|
|
baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt_justtag.cpython-311.pyc
ADDED
|
Binary file (6.5 kB). View file
|
|
|
baseline_generate/diffrhythm2/scripts/proce_song.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Process songs.jsonl to generate corresponding lrc files and jsonl files.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
# Source jsonl: one song object per line with "description" and "lyrics".
INPUT_JSONL = Path("xxx/diffrhythm2/example/final_zh_test.jsonl")
# Per-song jsonl descriptors are written here ...
OUTPUT_SONG_DIR = Path("xxx/diffrhythm2/example/zh_songs")
# ... and the matching plain-text .lrc lyric files here.
OUTPUT_LRC_DIR = Path("xxx/diffrhythm2/example/zh_lrc")

# Matches LRC time stamps such as "[01:23.45]".
TIMESTAMP_PATTERN = re.compile(r"\[\d{2}:\d{2}(?:\.\d+)?\]")
# Matches a line consisting of exactly one bracketed tag, e.g. "[Chorus]".
STRUCTURE_PATTERN = re.compile(r"^\[[^\]]+\]$")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def normalize_structure(tag: str) -> str:
    """Map a free-form section tag onto its canonical bracketed form."""
    lowered = tag.lower()
    if lowered.startswith("verse"):
        return "[verse]"
    for keyword in ("chorus", "bridge"):
        if keyword in lowered:
            return f"[{keyword}]"
    return f"[{lowered}]"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def transform_lyrics(raw_lyrics: str) -> List[str]:
    """Convert raw lyric text into the LRC line list used downstream.

    Output always begins with "[start]"/"[intro]" and ends with "[end]".
    Pure structure-tag lines are normalized; LRC time stamps are removed
    from lyric lines and blank results are dropped.
    """
    out = ["[start]", "[intro]"]
    for candidate in raw_lyrics.splitlines():
        candidate = candidate.strip()
        if not candidate:
            continue

        # A "[...]"-only line that is not a time stamp is a section tag.
        if STRUCTURE_PATTERN.match(candidate) and not TIMESTAMP_PATTERN.match(candidate):
            out.append(normalize_structure(candidate[1:-1].strip()))
            continue

        stripped = TIMESTAMP_PATTERN.sub("", candidate).strip()
        if stripped:
            out.append(stripped)

    out.append("[end]")
    return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def ensure_dirs() -> None:
    """Create the song and lyric output directories if missing."""
    for target in (OUTPUT_SONG_DIR, OUTPUT_LRC_DIR):
        target.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def process_songs() -> None:
    """Read INPUT_JSONL and emit one .lrc plus one .jsonl per song.

    Side effects: creates the output directories and writes (overwrites)
    song_<idx>.lrc and song_<idx>.jsonl files, printing progress.
    """
    ensure_dirs()
    with INPUT_JSONL.open("r", encoding="utf-8") as infile:
        for idx, line in enumerate(infile, start=1):
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)
            description = data.get("description", "")
            lyrics_raw = data.get("lyrics", "")

            # Lyrics -> normalized LRC file.
            lrc_lines = transform_lyrics(lyrics_raw)
            lrc_filename = f"song_{idx}.lrc"
            lrc_path = OUTPUT_LRC_DIR / lrc_filename
            lrc_path.write_text("\n".join(lrc_lines), encoding="utf-8")

            # Companion descriptor consumed by the generation script.
            song_base = f"song_{idx}"
            song_filename = f"{song_base}.jsonl"
            song_json_path = OUTPUT_SONG_DIR / song_filename
            song_entry = {
                "song_name": song_base,
                "style_prompt": description,
                # NOTE(review): relative path hard-codes zh_lrc — keep in
                # sync with OUTPUT_LRC_DIR if the constants change.
                "lyrics": f"example/zh_lrc/{lrc_filename}",
            }
            song_json_path.write_text(json.dumps(song_entry, ensure_ascii=False) + "\n", encoding="utf-8")
            print(f"Processed song {idx}: {song_filename}")


if __name__ == "__main__":
    process_songs()
|
baseline_generate/diffrhythm2/scripts/proce_song_enprompt.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def extract_user_prompt(messages):
    """Concatenate user-role message contents, capped at 1600 characters.

    Contents are joined with newlines.  A chunk that would push the
    joined string past 1600 characters is dropped, and no later chunks
    are considered, so paragraphs stay intact.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.

    Returns:
        The joined prompt string, or "" when nothing fits.
    """
    kept = []
    total = 0

    for message in messages:
        if message.get("role") != "user":
            continue
        chunk = message.get("content", "")
        if not chunk:
            continue

        # +1 accounts for the "\n" separator once something is kept.
        candidate_total = total + len(chunk) + (1 if kept else 0)
        if candidate_total > 1600:
            # Would exceed the cap: stop, skipping all later chunks too.
            break
        kept.append(chunk)
        total = candidate_total

    return "\n".join(kept) if kept else ""
|
| 48 |
+
|
| 49 |
+
def update_song_file(file_path, new_prompt):
    """Rewrite the first JSON record of *file_path* with a new style_prompt.

    Args:
        file_path: path to a jsonl song descriptor (first line is the record).
        new_prompt: replacement value for the record's "style_prompt" field.

    Parse/IO errors are reported on stdout; the file is rewritten only
    when the first record parses.
    """
    # Read file content, dropping blank lines.
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    if not lines:
        print(f" Warning: File {file_path} is empty, skipping")
        return

    # Read first JSON data
    try:
        data = json.loads(lines[0])
        # Update style_prompt field
        data['style_prompt'] = new_prompt

        # Write back to file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
            # FIX: the original wrote a single blank line here, silently
            # discarding any additional records; preserve them verbatim.
            for extra in lines[1:]:
                f.write(extra + '\n')

        print(f" ✓ Updated {file_path}")
    except json.JSONDecodeError as e:
        print(f" Error: JSON parsing failed {file_path}: {e}")
    except Exception as e:
        print(f" Error: Failed to update file {file_path}: {e}")
|
| 83 |
+
|
| 84 |
+
def main():
    """Batch-update style_prompt fields of prepared song descriptor files.

    Reads test_messages.jsonl (entries 1-50 map to Chinese songs, 51-100
    to English songs), extracts a prompt from each entry's user messages,
    and writes it into the matching song_<n>.jsonl descriptor.
    """
    # File paths ("xxx" roots are machine-specific placeholders)
    input_file = "xxx/diffrhythm2/scripts/test_messages.jsonl"
    zh_songs_dir = "xxx/diffrhythm2/example/zh_songs"
    en_songs_dir = "xxx/diffrhythm2/example/en_songs"

    print(f"Reading file: {input_file}")

    # Read all data, skipping blank lines
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    print(f"Read {len(lines)} entries")

    # Process each entry
    for idx, line in enumerate(lines, 1):
        try:
            data = json.loads(line)
            messages = data.get("messages", [])

            # Extract prompt (user contents joined, capped at 1600 chars)
            prompt = extract_user_prompt(messages)

            if not prompt:
                print(f"Processing entry {idx}: No user content found, skipping")
                continue

            # Determine if Chinese or English
            if idx <= 50:
                # First 50 entries: Chinese songs
                song_num = idx
                target_dir = zh_songs_dir
                lang = "Chinese"
            else:
                # Entries 51-100: English songs
                song_num = idx - 50  # 51->1, 52->2, ..., 100->50
                target_dir = en_songs_dir
                lang = "English"

            # Build file path
            song_file = os.path.join(target_dir, f"song_{song_num}.jsonl")

            print(f"Processing entry {idx} ({lang}, song_{song_num})...")
            print(f" Prompt length: {len(prompt)} characters")

            # Update file
            update_song_file(song_file, prompt)

        except json.JSONDecodeError as e:
            print(f"JSON parsing failed for entry {idx}: {e}")
            continue
        except Exception as e:
            # Broad catch keeps the batch running; trace aids debugging.
            print(f"Error processing entry {idx}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\nProcessing complete! Processed {len(lines)} entries")


if __name__ == "__main__":
    main()
|
| 145 |
+
|
baseline_generate/diffrhythm2/scripts/proce_song_enprompt_justtag.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def extract_user_prompt(messages):
    """Pull the style description out of the first user message.

    The message is expected to contain
    "Please generate a song in the following style:" followed by the
    style text and optionally a newline; only the text between the colon
    and the newline is returned.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.

    Returns:
        The extracted style string, or "" when the first non-empty user
        message lacks the expected prefix (or no user message exists).
    """
    marker = "Please generate a song in the following style:"
    for message in messages:
        if message.get("role") != "user":
            continue
        content = message.get("content", "")
        if not content:
            continue

        at = content.find(marker)
        if at == -1:
            # First non-empty user message decides; no marker -> give up.
            return ""
        tail = content[at + len(marker):]
        nl = tail.find("\n")
        return (tail if nl == -1 else tail[:nl]).strip()

    return ""
|
| 46 |
+
|
| 47 |
+
def update_song_file(file_path, new_prompt):
    """Rewrite the first JSON record of *file_path* with a new style_prompt.

    Args:
        file_path: path to a jsonl song descriptor (first line is the record).
        new_prompt: replacement value for the record's "style_prompt" field.

    Parse/IO errors are reported on stdout; the file is rewritten only
    when the first record parses.
    """
    # Read file content, dropping blank lines.
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    if not lines:
        print(f" Warning: File {file_path} is empty, skipping")
        return

    # Read first JSON data
    try:
        data = json.loads(lines[0])
        # Update style_prompt field
        data['style_prompt'] = new_prompt

        # Write back to file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
            # FIX: the original wrote a single blank line here, silently
            # discarding any additional records; preserve them verbatim.
            for extra in lines[1:]:
                f.write(extra + '\n')

        print(f" ✓ Updated {file_path}")
    except json.JSONDecodeError as e:
        print(f" Error: JSON parsing failed {file_path}: {e}")
    except Exception as e:
        print(f" Error: Failed to update file {file_path}: {e}")
|
| 81 |
+
|
| 82 |
+
def main():
    """Batch-update style_prompt fields with tag-only style strings.

    Reads test_messages.jsonl (entries 1-50 map to Chinese songs, 51-100
    to English songs), extracts only the style tag line from each
    entry's first user message, and writes it into the matching
    song_<n>.jsonl descriptor.
    """
    # File paths ("xxx" roots are machine-specific placeholders)
    input_file = "xxx/diffrhythm2/scripts/test_messages.jsonl"
    zh_songs_dir = "xxx/diffrhythm2/example/zh_songs"
    en_songs_dir = "xxx/diffrhythm2/example/en_songs"

    print(f"Reading file: {input_file}")

    # Read all data, skipping blank lines
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    print(f"Read {len(lines)} entries")

    # Process each entry
    for idx, line in enumerate(lines, 1):
        try:
            data = json.loads(line)
            messages = data.get("messages", [])

            # Extract prompt (style text after the fixed marker sentence)
            prompt = extract_user_prompt(messages)

            if not prompt:
                print(f"Processing entry {idx}: No user content found, skipping")
                continue

            # Determine if Chinese or English
            if idx <= 50:
                # First 50 entries: Chinese songs
                song_num = idx
                target_dir = zh_songs_dir
                lang = "Chinese"
            else:
                # Entries 51-100: English songs
                song_num = idx - 50  # 51->1, 52->2, ..., 100->50
                target_dir = en_songs_dir
                lang = "English"

            # Build file path
            song_file = os.path.join(target_dir, f"song_{song_num}.jsonl")

            print(f"Processing entry {idx} ({lang}, song_{song_num})...")
            print(f" Prompt length: {len(prompt)} characters")

            # Update file
            update_song_file(song_file, prompt)

        except json.JSONDecodeError as e:
            print(f"JSON parsing failed for entry {idx}: {e}")
            continue
        except Exception as e:
            # Broad catch keeps the batch running; trace aids debugging.
            print(f"Error processing entry {idx}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\nProcessing complete! Processed {len(lines)} entries")


if __name__ == "__main__":
    main()
baseline_generate/levo/__pycache__/generate.cpython-311.pyc
ADDED
|
Binary file (37.1 kB). View file
|
|
|
baseline_generate/levo/convert.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert data to ACE-STEP acceptable format
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
random.seed(42)
|
| 11 |
+
|
| 12 |
+
def load_jsonl(path: str) -> list[dict]:
    """Read a jsonl file into a list of dicts (one per line)."""
    with open(path, 'r') as file:
        return [json.loads(line) for line in tqdm(file, desc=f"Loading {path}")]
|
| 18 |
+
|
| 19 |
+
def save_jsonl(data: list, path: str):
    """Write *data* to *path* as jsonl (one JSON object per line)."""
    with open(path, 'w', encoding='utf-8') as file:
        for record in tqdm(data, desc=f"Saving {path}"):
            file.write(json.dumps(record, ensure_ascii=False))
            file.write("\n")
|
| 24 |
+
|
| 25 |
+
START_STR = "Please generate a song in the following style:"
|
| 26 |
+
END_STR = "\nNext, I will tell you the requirements and lyrics"
|
| 27 |
+
|
| 28 |
+
def process_tag(content: str) -> str:
    """Extract and canonicalize the section label from a segment header.

    *content* looks like "[Verse 1 (soft)][desc: ...]"; the label before
    "[desc:" is lower-cased, digits and parenthesized notes are removed,
    and pre-chorus is folded into chorus.
    """
    desc_at = content.find("[desc:")
    label = content[1:desc_at - 1].lower()
    label = re.sub(r'\d+', '', label)
    label = re.sub(r'\([^)]*\)', '', label).strip()
    return "[chorus]" if label == "pre-chorus" else f"[{label}]"
|
| 40 |
+
|
| 41 |
+
def process_lyrics(content: str) -> str:
    """Extract segment lyrics and normalize punctuation to '.' separators.

    Returns "" when the segment carries no "[lyrics:" section.  Listed
    punctuation (incl. newlines and brackets) collapses to single dots,
    and a trailing dot is dropped.
    """
    head = "[lyrics:\n"
    begin = content.find(head)
    if begin == -1:
        return ""
    finish = content.find("][phoneme:")
    text = content[begin + len(head):finish]

    # Map punctuation to '.', then collapse runs of dots.
    text = re.sub(r'[,。",:;&—‘\'.\]\[()?\n-]', '.', text)
    while ".." in text:
        text = text.replace("..", ".")
    return text[:-1] if text.endswith('.') else text
|
| 58 |
+
|
| 59 |
+
def random_size() -> str:
    """Pick a random intro/outro length bucket."""
    return random.choice(['short', 'medium', 'long'])
|
| 63 |
+
|
| 64 |
+
def process_one(messages: list[dict]):
    """Flatten one conversation into (descriptions, gt_lyric).

    The opening message carries the overall style description between
    START_STR and END_STR; each remaining non-assistant message
    contributes one "[tag] lyric ;" segment.  Random-length intro/outro
    markers bracket the result, and processing stops at the first
    empty-lyric or outro segment.
    """
    # Overall style description from the opening message.
    opener: str = messages[0]['content']
    begin = opener.find(START_STR)
    finish = opener.find(END_STR)
    descriptions = opener[begin + len(START_STR):finish]

    # Keep the original random_size() call order (intro first) so the
    # seeded random stream produces identical tags.
    start_tag = "intro-" + random_size()
    end_tag = "outro-" + random_size()
    terminal = f" [{end_tag}]"

    gt_lyric = f"[{start_tag}] ;"
    for turn in messages[1:]:
        if turn['role'] == "assistant":
            continue
        body = turn['content']
        tag = process_tag(body)
        lyric = process_lyrics(body)
        if lyric == "" or tag.startswith("[outro"):
            gt_lyric += terminal
            break
        gt_lyric += f" {tag} {lyric} ;"

    if not gt_lyric.endswith(terminal):
        gt_lyric += terminal
    return descriptions, gt_lyric
|
| 91 |
+
|
| 92 |
+
def main():
    """Convert a test_messages.jsonl conversation dump into LeVo inputs.

    Each conversation becomes one jsonl record with an idx, the extracted
    style description, and the flattened gt_lyric string.
    """
    # "xxx" roots are machine-specific placeholders.
    path = "xxx/SongGeneration/data/inputs/test_messages.jsonl"
    dataset = load_jsonl(path)
    save_path = "xxx/SongGeneration/data/inputs/lyrics.jsonl"

    with open(save_path, 'w', encoding='utf-8') as file:
        for id, ele in tqdm(enumerate(dataset), desc="Processing"):
            messages = ele['messages']
            descriptions, gt_lyric = process_one(messages)
            data = {
                "idx": f"test_{id}",
                "descriptions": descriptions,
                "gt_lyric": gt_lyric
            }
            json.dump(data, file, ensure_ascii=False)
            file.write("\n")


if __name__ == "__main__":
    main()
|
baseline_generate/levo/generate.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from hmac import new
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
import time
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import torchaudio
|
| 10 |
+
import numpy as np
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
from codeclm.models import builders
|
| 13 |
+
import gc
|
| 14 |
+
from codeclm.trainer.codec_song_pl import CodecLM_PL
|
| 15 |
+
from codeclm.models import CodecLM
|
| 16 |
+
from third_party.demucs.models.pretrained import get_model_from_yaml
|
| 17 |
+
import re
|
| 18 |
+
import subprocess
|
| 19 |
+
|
| 20 |
+
# Genre labels accepted for `auto_prompt_audio_type` in the input JSONL; each
# indexes a bank of pre-extracted prompt tokens loaded from tools/new_prompt.pt.
auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
|
| 21 |
+
|
| 22 |
+
def get_free_gpu() -> int:
    """Return the ID of the GPU with the most free memory (least usage)."""
    # Ask nvidia-smi for "<index>, <free MiB>" lines, one per visible GPU.
    cmd = "nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits"
    report = subprocess.check_output(cmd.split()).decode().strip().split("\n")

    # Parse into (GPU id, free memory MiB) pairs.
    stats = [(int(gid), int(mem)) for gid, mem in (row.split(",") for row in report)]

    # The device with the largest free-memory figure wins; on ties the first
    # listed GPU is chosen (same as a stable descending sort).
    return max(stats, key=lambda pair: pair[1])[0]
|
| 35 |
+
|
| 36 |
+
class Separator:
    """Demucs-based source separator used to build audio prompts.

    Splits a reference song into its vocal stem and a derived BGM (mixture
    minus vocals), each truncated to the first 10 seconds at 48 kHz.
    """

    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
        # NOTE(review): the gpu_id parameter is immediately overwritten — the
        # freest GPU is always auto-selected. Confirm the override is intended.
        gpu_id = get_free_gpu()
        self.device = f"cuda:{gpu_id}"
        print(f"Using {self.device}")

        # if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
        #     self.device = torch.device(f"cuda:{gpu_id}")
        # else:
        #     self.device = torch.device("cpu")

        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path):
        """Load the Demucs model from its YAML config onto `self.device` in eval mode."""
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model

    def load_audio(self, f):
        """Load `f`, resample to 48 kHz, and return at most the first 10 s."""
        a, fs = torchaudio.load(f)
        if (fs != 48000):
            a = torchaudio.functional.resample(a, fs, 48000)
        if a.shape[-1] >= 48000*10:
            a = a[..., :48000*10]
        # NOTE(review): this slice repeats the truncation above; as written it
        # is redundant but harmless.
        return a[:, 0:48000*10]

    def run(self, audio_path, output_dir='tmp', ext=".flac"):
        """Separate `audio_path`; return (full, vocal, bgm) waveform tensors.

        Stems already present in `output_dir` are reused; otherwise Demucs is
        run and the non-vocal stem files are deleted afterwards.
        """
        os.makedirs(output_dir, exist_ok=True)
        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
        output_paths = []

        # Collect stem files left over from a previous run of the same input.
        for stem in self.demucs_model.sources:
            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
            if os.path.exists(output_path):
                output_paths.append(output_path)
        if len(output_paths) == 1:  # 4
            # NOTE(review): exactly one cached file is assumed to be the vocal
            # stem; the "# 4" hint suggests all four stems may once have been
            # required here — verify against upstream.
            vocal_path = output_paths[0]
        else:
            # Run Demucs; keep only the vocal stem on disk.
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
            for path in [drums_path, bass_path, other_path]:
                os.remove(path)
        full_audio = self.load_audio(audio_path)
        vocal_audio = self.load_audio(vocal_path)
        # BGM is derived by subtracting vocals from the mixture.
        bgm_audio = full_audio - vocal_audio
        return full_audio, vocal_audio, bgm_audio
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def parse_args():
    """Build the CLI parser and return the parsed command-line arguments."""
    cli = argparse.ArgumentParser(description='Song Generation Script')

    # Required path arguments, declared as (flag, help text) pairs.
    required_opts = [
        ('--ckpt_path', 'Path to the checkpoint directory containing config.yaml and model.pt'),
        ('--input_jsonl', 'Path to input JSONL file containing generation tasks'),
        ('--save_dir', 'Directory to save generated audio files and results'),
    ]
    for flag, help_text in required_opts:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional behaviour switches.
    cli.add_argument('--generate_type', type=str, default='mixed',
                     help='Type of generation: "vocal" or "bgm" or "separate" or "mixed" (default: "mixed")')
    cli.add_argument('--use_flash_attn', action='store_true',
                     help='Whether to use flash attention (default: False)')
    cli.add_argument('--low_mem', action='store_true',
                     help='Whether to use low memory mode (default: False)')
    return cli.parse_args()
|
| 102 |
+
|
| 103 |
+
def generate(args):
    """Full-memory inference path: everything stays resident on the GPU.

    Reads one task per line from ``args.input_jsonl`` (fields: ``idx``,
    ``gt_lyric``, optional ``descriptions`` plus either ``prompt_audio_path``
    or ``auto_prompt_audio_type``), generates audio for each, and writes the
    FLAC files and a result JSONL under ``args.save_dir``.
    """
    torch.set_num_threads(1)
    ckpt_path = args.ckpt_path
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    # The checkpoint directory must contain config.yaml and model.pt.
    cfg_path = os.path.join(ckpt_path, 'config.yaml')
    ckpt_path = os.path.join(ckpt_path, 'model.pt')
    cfg = OmegaConf.load(cfg_path)
    cfg.lm.use_flash_attn_2 = args.use_flash_attn
    print(f"use_flash_attn: {args.use_flash_attn}")
    cfg.mode = 'inference'
    max_duration = cfg.max_dur
    gen_type = args.generate_type


    # Phase 1: prompt preparation (separator + audio tokenizer on GPU).
    separator = Separator()
    auto_prompt = torch.load('tools/new_prompt.pt')
    audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
    audio_tokenizer = audio_tokenizer.eval().cuda()
    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()


    new_items = []
    for line in lines:
        item = json.loads(line)
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        # get prompt audio
        if "prompt_audio_path" in item:
            # A reference song was supplied: separate it and tokenize the stems.
            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
            with torch.no_grad():
                pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
            # Keep the raw waveforms for the decoding stage later on.
            item['raw_pmt_wav'] = pmt_wav
            item['raw_vocal_wav'] = vocal_wav
            item['raw_bgm_wav'] = bgm_wav
            # Normalize each waveform to [B, C, T] before batching.
            if pmt_wav.dim() == 2:
                pmt_wav = pmt_wav[None]
            if pmt_wav.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            pmt_wav = list(pmt_wav)
            if vocal_wav.dim() == 2:
                vocal_wav = vocal_wav[None]
            if vocal_wav.dim() != 3:
                raise ValueError("Vocal wavs should have a shape [B, C, T].")
            vocal_wav = list(vocal_wav)
            if bgm_wav.dim() == 2:
                bgm_wav = bgm_wav[None]
            if bgm_wav.dim() != 3:
                raise ValueError("BGM wavs should have a shape [B, C, T].")
            bgm_wav = list(bgm_wav)
            if type(pmt_wav) == list:
                pmt_wav = torch.stack(pmt_wav, dim=0)
            if type(vocal_wav) == list:
                vocal_wav = torch.stack(vocal_wav, dim=0)
            if type(bgm_wav) == list:
                bgm_wav = torch.stack(bgm_wav, dim=0)
            # NOTE(review): the three self-assignments below are no-ops.
            pmt_wav = pmt_wav
            vocal_wav = vocal_wav
            bgm_wav = bgm_wav
            with torch.no_grad():
                pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
            melody_is_wav = False
        elif "auto_prompt_audio_type" in item:
            # No reference audio: draw a random pre-extracted prompt token
            # bank for the requested genre. Channels 0/1/2 hold the mixed,
            # vocal and BGM prompt tokens respectively.
            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
            pmt_wav = prompt_token[:,[0],:]
            vocal_wav = prompt_token[:,[1],:]
            bgm_wav = prompt_token[:,[2],:]
            melody_is_wav = False
        else:
            # Unconditional generation: no melody prompt at all.
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True
        item['pmt_wav'] = pmt_wav
        item['vocal_wav'] = vocal_wav
        item['bgm_wav'] = bgm_wav
        item['melody_is_wav'] = melody_is_wav
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name
        new_items.append(item)

    # Free the prompt-preparation models before loading the LM.
    del audio_tokenizer
    del separator

    torch.cuda.empty_cache()

    if "audio_tokenizer_checkpoint_sep" in cfg.keys():
        seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
    else:
        seperate_tokenizer = None

    if seperate_tokenizer is not None:
        seperate_tokenizer = seperate_tokenizer.eval().cuda()

    # Encode the separated vocal/BGM stems into tokens for prompt-audio items.
    for item in new_items:
        if "prompt_audio_path" in item:
            with torch.no_grad():
                vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
            item['vocal_wav'] = vocal_wav
            item['bgm_wav'] = bgm_wav

    torch.cuda.empty_cache()
    # Phase 2: load the language model (fp16, GPU-resident).
    audiolm = builders.get_lm_model(cfg)
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # Strip the "audiolm." prefix used by the Lightning training wrapper.
    audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
    audiolm.load_state_dict(audiolm_state_dict, strict=False)
    audiolm = audiolm.eval()
    audiolm = audiolm.cuda().to(torch.float16)

    model = CodecLM(name = "tmp",
        lm = audiolm,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = seperate_tokenizer,
        )

    # Sampling hyperparameters for token generation.
    cfg_coef = 1.5 #25
    temp = 0.9
    top_k = 50
    top_p = 0.0
    record_tokens = True
    record_window = 50

    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)

    # Phase 3: per-item token generation followed by audio decoding.
    for item in new_items:
        lyric = item["gt_lyric"]
        descriptions = item["descriptions"] if "descriptions" in item else None
        pmt_wav = item['pmt_wav']
        vocal_wav = item['vocal_wav']
        bgm_wav = item['bgm_wav']
        melody_is_wav = item['melody_is_wav']
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"


        generate_inp = {
            # NOTE(review): as rendered this replace is a no-op (space -> space);
            # it likely collapsed doubled spaces originally — confirm upstream.
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [descriptions],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }
        start_time = time.time()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            with torch.no_grad():
                tokens = model.generate(**generate_inp, return_tokens=True)
        mid_time = time.time()

        with torch.no_grad():
            if 'raw_pmt_wav' in item:
                # Prompt-audio path: decode with the raw reference stems.
                if gen_type == 'separate':
                    wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='mixed')
                    wav_vocal = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='bgm')
                elif gen_type == 'mixed':
                    wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
                else:
                    wav_seperate = model.generate_audio(tokens,chunked=True, gen_type=gen_type)
                # Raw waveforms are no longer needed (and are not JSON-serializable).
                del item['raw_pmt_wav']
                del item['raw_vocal_wav']
                del item['raw_bgm_wav']
            else:
                if gen_type == 'separate':
                    wav_vocal = model.generate_audio(tokens, chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(tokens, chunked=True, gen_type='bgm')
                    wav_seperate = model.generate_audio(tokens, chunked=True, gen_type='mixed')
                else:
                    wav_seperate = model.generate_audio(tokens, chunked=True, gen_type=gen_type)
        # Drop tensor fields so the item can be serialized to JSONL below.
        del item['pmt_wav']
        del item['vocal_wav']
        del item['bgm_wav']
        del item['melody_is_wav']
        end_time = time.time()
        if gen_type == 'separate':
            # Write vocal and BGM stems alongside the mixed track.
            torchaudio.save(target_wav_name.replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(target_wav_name.replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
        else:
            torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)

        print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name

    # Persist the processed task records next to the generated audio.
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
| 298 |
+
|
| 299 |
+
def generate_lowmem(args):
    """Low-memory inference path: models are loaded and freed in phases.

    Same task format and outputs as ``generate``, but the LM (optionally with
    layer offloading) runs first for ALL items, is then released, and only
    afterwards is the separation tokenizer loaded to decode the cached tokens.
    This keeps peak GPU memory below what the single-pass path requires.
    """
    torch.set_num_threads(1)
    ckpt_path = args.ckpt_path
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    cfg_path = os.path.join(ckpt_path, 'config.yaml')
    ckpt_path = os.path.join(ckpt_path, 'model.pt')
    cfg = OmegaConf.load(cfg_path)
    cfg.lm.use_flash_attn_2 = args.use_flash_attn
    print(f"use_flash_attn: {args.use_flash_attn}")
    cfg.mode = 'inference'
    max_duration = cfg.max_dur
    gen_type = args.generate_type
    chunk_size = 128  # NOTE(review): appears unused in this function — confirm.
    # Only instantiate the separator/tokenizer if any task needs prompt audio.
    use_audio_tokenizer = False
    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()
    for line in lines:
        item = json.loads(line)
        if "prompt_audio_path" in item:
            use_audio_tokenizer = True
            break
    if use_audio_tokenizer:
        separator = Separator()
        audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
        audio_tokenizer = audio_tokenizer.eval().cuda()
    auto_prompt = torch.load('tools/new_prompt.pt')
    new_items = []
    for line in lines:
        item = json.loads(line)
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        # get prompt audio
        if "prompt_audio_path" in item:
            # Reference song supplied: separate stems and tokenize the mixture.
            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
            with torch.no_grad():
                pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
            # Raw stems are kept for the later decoding phase.
            item['raw_pmt_wav'] = pmt_wav
            item['raw_vocal_wav'] = vocal_wav
            item['raw_bgm_wav'] = bgm_wav
            # Normalize each waveform to [B, C, T] before batching.
            if pmt_wav.dim() == 2:
                pmt_wav = pmt_wav[None]
            if pmt_wav.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            pmt_wav = list(pmt_wav)
            if vocal_wav.dim() == 2:
                vocal_wav = vocal_wav[None]
            if vocal_wav.dim() != 3:
                raise ValueError("Vocal wavs should have a shape [B, C, T].")
            vocal_wav = list(vocal_wav)
            if bgm_wav.dim() == 2:
                bgm_wav = bgm_wav[None]
            if bgm_wav.dim() != 3:
                raise ValueError("BGM wavs should have a shape [B, C, T].")
            bgm_wav = list(bgm_wav)
            if type(pmt_wav) == list:
                pmt_wav = torch.stack(pmt_wav, dim=0)
            if type(vocal_wav) == list:
                vocal_wav = torch.stack(vocal_wav, dim=0)
            if type(bgm_wav) == list:
                bgm_wav = torch.stack(bgm_wav, dim=0)
            with torch.no_grad():
                pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
            melody_is_wav = False
        elif "auto_prompt_audio_type" in item:
            # Draw a random pre-extracted prompt for the requested genre;
            # channels 0/1/2 hold mixed / vocal / BGM prompt tokens.
            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
            pmt_wav = prompt_token[:,[0],:]
            vocal_wav = prompt_token[:,[1],:]
            bgm_wav = prompt_token[:,[2],:]
            melody_is_wav = False
        else:
            # Unconditional generation.
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True
        item['pmt_wav'] = pmt_wav
        item['vocal_wav'] = vocal_wav
        item['bgm_wav'] = bgm_wav
        item['melody_is_wav'] = melody_is_wav
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name
        new_items.append(item)

    if use_audio_tokenizer:
        del audio_tokenizer
        del separator

    torch.cuda.empty_cache()

    if "audio_tokenizer_checkpoint_sep" in cfg.keys() and use_audio_tokenizer:
        seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
    else:
        seperate_tokenizer = None

    if seperate_tokenizer is not None:
        seperate_tokenizer = seperate_tokenizer.eval().cuda()

    # Tokenize the separated stems for prompt-audio items.
    for item in new_items:
        if "prompt_audio_path" in item:
            with torch.no_grad():
                vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
            item['vocal_wav'] = vocal_wav
            item['bgm_wav'] = bgm_wav

    if use_audio_tokenizer:
        del seperate_tokenizer

    torch.cuda.empty_cache()

    # Define model or load pretrained model
    audiolm = builders.get_lm_model(cfg)
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # Strip the "audiolm." prefix used by the Lightning training wrapper.
    audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
    audiolm.load_state_dict(audiolm_state_dict, strict=False)
    audiolm = audiolm.eval()

    # Optionally wrap the LM in the layer-offload profiler configured in cfg.
    # OffloadProfiler / OffloadParamParse are imported at the call site in
    # __main__ before this function runs.
    offload_audiolm = True if 'offload' in cfg.keys() and 'audiolm' in cfg.offload else False
    if offload_audiolm:
        audiolm_offload_param = OffloadParamParse.parse_config(audiolm, cfg.offload.audiolm)
        audiolm_offload_param.show()
        offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
        offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
        offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
    else:
        audiolm = audiolm.cuda().to(torch.float16)

    # The separation tokenizer is deliberately absent here; this CodecLM is
    # used only for token generation in the first pass.
    model = CodecLM(name = "tmp",
        lm = audiolm,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = None,
        )

    # Sampling hyperparameters for token generation.
    cfg_coef = 1.5 #25
    temp = 0.9
    top_k = 50
    top_p = 0.0
    record_tokens = True
    record_window = 50


    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)


    # Pass 1: generate tokens for every item while only the LM is on GPU.
    for item in new_items:
        lyric = item["gt_lyric"]
        descriptions = item["descriptions"] if "descriptions" in item else None
        pmt_wav = item['pmt_wav']
        vocal_wav = item['vocal_wav']
        bgm_wav = item['bgm_wav']
        melody_is_wav = item['melody_is_wav']

        generate_inp = {
            # NOTE(review): as rendered this replace is a no-op (space -> space);
            # it likely collapsed doubled spaces originally — confirm upstream.
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [descriptions],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            with torch.no_grad():
                tokens = model.generate(**generate_inp, return_tokens=True)
        if offload_audiolm:
            offload_profiler.reset_empty_cache_mem_line()
        # Cache tokens on the item so the LM can be freed before decoding.
        item['tokens'] = tokens
    # Tear down the LM (and its offload machinery) to reclaim GPU memory.
    if offload_audiolm:
        offload_profiler.stop()
        del offload_profiler
        del audiolm_offload_param
    del model
    audiolm = audiolm.cpu()
    del audiolm
    del checkpoint
    gc.collect()
    torch.cuda.empty_cache()

    # Pass 2: load the separation tokenizer (VAE on GPU) to decode tokens.
    seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint_sep, cfg)
    device = "cuda:0"
    seperate_tokenizer.model.device = device
    seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
    seperate_tokenizer.model.model.device = torch.device(device)
    seperate_tokenizer = seperate_tokenizer.eval()

    # offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
    # Diffusion-decoder offloading is currently hard-disabled.
    offload_wav_tokenizer_diffusion = False
    if offload_wav_tokenizer_diffusion:
        sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
        sep_offload_param.show()
        sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
        sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
        sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
    else:
        seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)

    # Decode-only CodecLM: no LM, just the separation tokenizer.
    model = CodecLM(name = "tmp",
        lm = None,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = seperate_tokenizer,
        )

    for item in new_items:
        with torch.no_grad():
            if 'raw_pmt_wav' in item:
                if gen_type == 'separate':
                    wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type='mixed')
                    wav_vocal = model.generate_audio(item['tokens'],chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
                elif gen_type == 'mixed':
                    wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
                else:
                    wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
                del item['raw_pmt_wav']
                del item['raw_vocal_wav']
                del item['raw_bgm_wav']
            else:
                if gen_type == 'separate':
                    wav_vocal = model.generate_audio(item['tokens'], chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
                    wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type='mixed')
                else:
                    wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
        if gen_type == 'separate':
            # Write vocal and BGM stems alongside the mixed track.
            torchaudio.save(item['wav_path'].replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(item['wav_path'].replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
        else:
            torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
        # Drop tensor fields so the item can be serialized to JSONL below.
        del item['tokens']
        del item['pmt_wav']
        del item['vocal_wav']
        del item['bgm_wav']
        del item['melody_is_wav']
        if offload_wav_tokenizer_diffusion:
            sep_offload_profiler.reset_empty_cache_mem_line()

    if offload_wav_tokenizer_diffusion:
        sep_offload_profiler.stop()
    torch.cuda.empty_cache()
    # Persist the processed task records next to the generated audio.
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
if __name__ == "__main__":
    torch.backends.cudnn.enabled = False
    # Config resolvers used by the checkpoint's config.yaml.
    # SECURITY: the "eval" resolver executes arbitrary Python from the config
    # file — only load checkpoints/configs from trusted sources.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
    # Time-based seed: auto-prompt selection differs between runs by design.
    np.random.seed(int(time.time()))
    # Parse command line arguments
    args = parse_args()
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        reserved = torch.cuda.memory_reserved(device)
        total = torch.cuda.get_device_properties(device).total_memory
        # NOTE(review): res_mem is the UNreserved (free) memory in GB, despite
        # the "reserved memory" wording in the print below.
        res_mem = (total - reserved) / 1024 / 1024 / 1024
        print(f"reserved memory: {res_mem}GB")

        # The checkpoint folder name selects the model variant and its
        # memory threshold for choosing the full vs low-memory path.
        model_name = args.ckpt_path.split("/")[-1].lower().replace('-', '_')
        assert model_name in ['songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full', 'songgeneration_large'], f'{model_name} is not supported, currently only songgeneration_base, songgeneration_base_new, songgeneration_base_full, songgeneration_large are supported. Please download correct files and rename the folder to the corresponding version name.'
        if model_name == 'songgeneration_base' or model_name == 'songgeneration_base_new' or model_name == 'songgeneration_base_full':
            if res_mem > 24 and not args.low_mem:
                print("use generate")
                generate(args)
            else:
                # Imported here (module scope) so generate_lowmem can use the
                # offload helpers as globals; avoided on the full-memory path.
                from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
                print("use generate_lowmem")
                generate_lowmem(args)
        elif model_name == 'songgeneration_large':
            if res_mem > 36 and not args.low_mem:
                print("use generate")
                generate(args)
            else:
                print("use generate_lowmem")
                from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
                generate_lowmem(args)


        # elif model_name == 'songgeneration_base_full':

    else:
        print("CUDA is not available")
        exit()
|
| 591 |
+
|
baseline_generate/mureka_o2/__pycache__/generate.cpython-311.pyc
ADDED
|
Binary file (19.5 kB). View file
|
|
|
baseline_generate/mureka_o2/generate.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
Script for batch generating songs using the Mureka API.

Processes the first MAX_SONGS entries of cleaned_data_no_desc.json,
submitting each song serially (concurrency = 1) and polling the query
endpoint until every task reaches a terminal state.
"""

import json
import os
import time
import requests
from typing import Dict, List, Optional
from pathlib import Path

# API Configuration
API_URL = "https://api.mureka.cn/v1/song/generate"
QUERY_API_URL = "https://api.mureka.cn/v1/song/query"
API_KEY_ENV = "MUREKA_API_KEY"  # environment variable holding the bearer token
MODEL = "mureka-o2"

# Configuration Parameters
MAX_SONGS = 100
RETRY_TIMES = 3
RETRY_DELAY = 2  # seconds between retries on generic request failures
REQUEST_DELAY = 60  # Delay between requests (seconds) - set to 60 seconds (1 minute) to avoid rate limiting
RATE_LIMIT_DELAY = 60  # Wait time when encountering 429 error (seconds)
QUERY_INTERVAL = 10  # Interval for querying task status (seconds)
MAX_QUERY_TIME = 3600  # Maximum query time (seconds), 1 hour
|
| 28 |
+
|
| 29 |
+
def load_songs(json_file: str, max_count: int = MAX_SONGS) -> List[Dict]:
    """Read the song list from *json_file* and return at most *max_count* entries."""
    with open(json_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    # Keep only the leading slice; everything past max_count is ignored.
    return records[:max_count]
|
| 36 |
+
|
| 37 |
+
def is_song_processed(output_file: Path) -> bool:
    """
    Check whether a song already has a saved API response on disk.

    Args:
        output_file: Output file path

    Returns:
        True when the file exists and contains an ``api_response`` whose
        status is any state the API can report (terminal or in-flight);
        False for missing, corrupted, or statusless files.
    """
    if not output_file.exists():
        return False

    # Terminal states plus every in-flight state: a task in any of these
    # was successfully created, so the song counts as processed.
    known_states = (
        'succeeded', 'failed', 'timeouted', 'cancelled',
        'preparing', 'queued', 'running', 'streaming', 'reviewing',
    )

    try:
        with open(output_file, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        if 'api_response' in payload and payload['api_response']:
            return payload['api_response'].get('status', '') in known_states
    except (json.JSONDecodeError, KeyError, IOError):
        # Corrupted or unreadable file -> treat as not processed.
        return False

    return False
|
| 67 |
+
|
| 68 |
+
def load_processed_song(output_file: Path) -> Optional[Dict]:
    """
    Read a previously saved result file and return its ``api_response``.

    Args:
        output_file: Output file path

    Returns:
        The stored API response dict, or None when the file is missing,
        unreadable, or not valid JSON.
    """
    try:
        with open(output_file, 'r', encoding='utf-8') as handle:
            saved = json.load(handle)
        return saved.get('api_response')
    except (json.JSONDecodeError, KeyError, IOError):
        return None
|
| 84 |
+
|
| 85 |
+
def query_task_status(task_id: str, api_key: str) -> Optional[Dict]:
    """
    Fetch the current state of one generation task.

    Args:
        task_id: Task ID
        api_key: API key

    Returns:
        Parsed JSON payload from the query endpoint, or None when the
        request (or its JSON decoding) fails.
    """
    try:
        response = requests.get(
            f"{QUERY_API_URL}/{task_id}",
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f" Failed to query task status: {str(e)}")
        return None
|
| 109 |
+
|
| 110 |
+
def wait_for_task_completion(task_id: str, api_key: str) -> Optional[Dict]:
    """
    Poll the query endpoint until the task reaches a terminal state.

    Args:
        task_id: Task ID
        api_key: API key

    Returns:
        The final task payload once the status is terminal, or None when
        polling exceeds MAX_QUERY_TIME.
    """
    deadline = time.time() + MAX_QUERY_TIME
    previous_status = None

    print(f" Waiting for task completion (Task ID: {task_id})...")

    while time.time() < deadline:
        snapshot = query_task_status(task_id, api_key)

        # Transient query failure: just wait and poll again.
        if not snapshot:
            time.sleep(QUERY_INTERVAL)
            continue

        status = snapshot.get('status', '')

        # Report each status transition exactly once.
        if status != previous_status:
            print(f" Status: {status}")
            previous_status = status

        # Terminal state (success or any kind of failure) ends the wait.
        if status in ['succeeded', 'failed', 'timeouted', 'cancelled']:
            if status == 'succeeded':
                print(f" ✓ Task completed!")
                if 'choices' in snapshot and snapshot['choices']:
                    print(f" Found {len(snapshot['choices'])} generated songs")
            else:
                print(f" ✗ Task failed: {status}")
                if 'failed_reason' in snapshot:
                    print(f" Failure reason: {snapshot['failed_reason']}")
            return snapshot

        time.sleep(QUERY_INTERVAL)

    print(f" ⚠ Query timeout (exceeded {MAX_QUERY_TIME} seconds)")
    return None
|
| 157 |
+
|
| 158 |
+
def generate_song(lyrics: str, prompt: str, api_key: str) -> Optional[Dict]:
    """
    Call the generation API for a single song (serial, concurrency = 1).

    Retries up to RETRY_TIMES. A 429 response honours the Retry-After
    header when it is a plain integer, otherwise waits RATE_LIMIT_DELAY.

    Args:
        lyrics: Lyrics content
        prompt: Prompt (corresponds to description)
        api_key: API key

    Returns:
        API response data (task-creation payload), or None on failure.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "lyrics": lyrics,
        "model": MODEL,
        "prompt": prompt
    }

    for attempt in range(RETRY_TIMES):
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=300)

            # Handle rate limiting before raise_for_status() so we can
            # honour the server-suggested wait instead of failing.
            if response.status_code == 429:
                retry_after = response.headers.get('Retry-After')
                # Retry-After may be absent or an HTTP-date; fall back to
                # the configured delay when it is not a plain integer.
                try:
                    wait_time = int(retry_after) if retry_after else RATE_LIMIT_DELAY
                except ValueError:
                    wait_time = RATE_LIMIT_DELAY

                print(f" Attempt {attempt + 1}/{RETRY_TIMES} failed: 429 Too Many Requests")
                print(f" Waiting {wait_time} seconds before retry...")
                if attempt < RETRY_TIMES - 1:
                    time.sleep(wait_time)
                    continue
                else:
                    print(f" All retries failed")
                    return None

            response.raise_for_status()
            return response.json()

        except requests.exceptions.HTTPError as e:
            # BUG FIX: requests.Response is falsy for 4xx/5xx statuses, so
            # the original `if e.response and ...` could never match; the
            # response must be compared against None explicitly.
            if e.response is not None and e.response.status_code == 429:
                # 429 error already handled above
                continue
            print(f" Attempt {attempt + 1}/{RETRY_TIMES} failed: {str(e)}")
            if attempt < RETRY_TIMES - 1:
                time.sleep(RETRY_DELAY)
            else:
                print(f" All retries failed")
                return None
        except requests.exceptions.RequestException as e:
            print(f" Attempt {attempt + 1}/{RETRY_TIMES} failed: {str(e)}")
            if attempt < RETRY_TIMES - 1:
                time.sleep(RETRY_DELAY)
            else:
                print(f" All retries failed")
                return None

    return None
|
| 225 |
+
|
| 226 |
+
def main() -> None:
    """Batch driver: load songs, submit each serially, poll to completion,
    and persist per-song JSON plus a summary under ./generated_songs.
    """
    # Check API key
    api_key = os.getenv(API_KEY_ENV)
    if not api_key:
        print(f"Error: Please set environment variable {API_KEY_ENV}")
        print(f"Example: export {API_KEY_ENV}=your_api_key")
        return

    # Load song data
    json_file = "cleaned_data_no_desc.json"
    if not os.path.exists(json_file):
        print(f"Error: File not found {json_file}")
        return

    print(f"Loading song data from {json_file}...")
    songs = load_songs(json_file, MAX_SONGS)
    print(f"Loaded {len(songs)} songs")

    # Create output directory
    output_dir = Path("generated_songs")
    output_dir.mkdir(exist_ok=True)

    # List to save results
    results = []

    # Process each song (serial processing, ensuring concurrency = 1)
    for idx, song in enumerate(songs, 1):
        print(f"\n[{idx}/{len(songs)}] Processing song...")
        print(f" Description: {song.get('description', 'N/A')[:50]}...")

        # Check output file path
        output_file = output_dir / f"song_{idx:03d}.json"

        # Check if already processed (makes reruns resumable)
        if is_song_processed(output_file):
            print(f" ⊙ Already processed, skipping")
            # Load processed results from file
            # NOTE(review): existing_result is loaded but never used below —
            # presumably meant to enrich the summary; confirm before removing.
            existing_result = load_processed_song(output_file)
            results.append({
                "index": idx,
                "status": "already_processed",
                "output_file": str(output_file)
            })
            # Already processed songs don't need delay
            continue

        lyrics = song.get('lyrics', '')
        description = song.get('description', '')

        if not lyrics or not description:
            print(f" Skipping: missing lyrics or description")
            results.append({
                "index": idx,
                "status": "skipped",
                "reason": "missing data"
            })
            continue

        # Call API (serial execution, ensuring concurrency = 1)
        result = generate_song(lyrics, description, api_key)

        if result:
            task_id = result.get('id')
            initial_status = result.get('status', '')
            print(f" ✓ Task created (ID: {task_id}, Status: {initial_status})")

            # If task status is not final, wait for task completion
            if initial_status not in ['succeeded', 'failed', 'timeouted', 'cancelled']:
                final_result = wait_for_task_completion(task_id, api_key)
                if final_result:
                    result = final_result
                else:
                    # Query timeout, use initial result
                    print(f" ⚠ Using initial result (query timeout)")

            # Save single result (including final status and choices)
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump({
                    "index": idx,
                    "original_data": song,
                    "api_response": result,
                    "task_id": task_id
                }, f, ensure_ascii=False, indent=2)

            # Classify the outcome for the summary
            final_status = result.get('status', '')
            if final_status == 'succeeded':
                results.append({
                    "index": idx,
                    "status": "success",
                    "output_file": str(output_file),
                    "task_id": task_id,
                    "has_audio": 'choices' in result and len(result.get('choices', [])) > 0
                })
            elif final_status in ['failed', 'timeouted', 'cancelled']:
                results.append({
                    "index": idx,
                    "status": "failed",
                    "output_file": str(output_file),
                    "task_id": task_id,
                    "failed_reason": result.get('failed_reason', final_status)
                })
            else:
                # Task still processing (non-terminal status after polling)
                results.append({
                    "index": idx,
                    "status": "processing",
                    "output_file": str(output_file),
                    "task_id": task_id,
                    "current_status": final_status
                })
        else:
            print(f" ✗ Generation failed")
            # Save failure information, including original data
            error_file = output_dir / f"song_{idx:03d}_error.json"
            with open(error_file, 'w', encoding='utf-8') as f:
                json.dump({
                    "index": idx,
                    "original_data": song,
                    "error": "API call failed, generate_song returned None",
                    "timestamp": time.time()
                }, f, ensure_ascii=False, indent=2)

            results.append({
                "index": idx,
                "status": "failed",
                "error_file": str(error_file),
                "reason": "API call failed"
            })

        # Delay between requests to avoid rate limiting (ensuring concurrency = 1)
        if idx < len(songs):
            print(f" Waiting {REQUEST_DELAY} seconds before processing next song...")
            time.sleep(REQUEST_DELAY)

    # Save summary results
    summary_file = output_dir / "summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump({
            "total": len(songs),
            "success": sum(1 for r in results if r.get("status") == "success"),
            "processing": sum(1 for r in results if r.get("status") == "processing"),
            "already_processed": sum(1 for r in results if r.get("status") == "already_processed"),
            "failed": sum(1 for r in results if r.get("status") == "failed"),
            "skipped": sum(1 for r in results if r.get("status") == "skipped"),
            "results": results
        }, f, ensure_ascii=False, indent=2)

    # Print statistics
    print(f"\n{'='*50}")
    print(f"Processing complete!")
    print(f"Total: {len(songs)} songs")
    print(f"Successfully completed: {sum(1 for r in results if r.get('status') == 'success')} songs")
    print(f"Processing: {sum(1 for r in results if r.get('status') == 'processing')} songs")
    print(f"Already processed: {sum(1 for r in results if r.get('status') == 'already_processed')} songs")
    print(f"Failed: {sum(1 for r in results if r.get('status') == 'failed')} songs")
    print(f"Skipped: {sum(1 for r in results if r.get('status') == 'skipped')} songs")
    print(f"Results saved in: {output_dir}/")
    print(f"Summary file: {summary_file}")
    print(f"\nTip: If tasks are still processing, you can rerun the script later to check status")
|
| 387 |
+
|
| 388 |
+
# Script entry point
if __name__ == "__main__":
    main()
|
| 390 |
+
|
baseline_generate/suno/__pycache__/suno_4_5.cpython-311.pyc
ADDED
|
Binary file (41.6 kB). View file
|
|
|
baseline_generate/suno/__pycache__/suno_5.cpython-311.pyc
ADDED
|
Binary file (41.7 kB). View file
|
|
|
baseline_generate/suno/config.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration file example
|
| 3 |
+
Copy this file as config.py and fill in your actual configuration
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# ============== API Configuration ==============
|
| 7 |
+
|
| 8 |
+
# Suno API key (obtain from https://sunoapi.org)
|
| 9 |
+
SUNO_API_KEY = ""
|
| 10 |
+
|
| 11 |
+
# API base URL (usually no need to modify)
|
| 12 |
+
SUNO_API_BASE_URL = "https://api.sunoapi.org"
|
| 13 |
+
|
| 14 |
+
# ============== Generation Configuration ==============
|
| 15 |
+
|
| 16 |
+
# Default model version
|
| 17 |
+
DEFAULT_MODEL_VERSION = "V5" # Options: V3_5, V4, V4_5, V4_5PLUS, V5
|
| 18 |
+
|
| 19 |
+
# Whether to enable custom mode
|
| 20 |
+
DEFAULT_CUSTOM_MODE = True
|
| 21 |
+
|
| 22 |
+
# Whether to generate instrumental by default
|
| 23 |
+
DEFAULT_INSTRUMENTAL = False
|
| 24 |
+
|
| 25 |
+
# ============== Task Configuration ==============
|
| 26 |
+
|
| 27 |
+
# Maximum wait time (seconds)
|
| 28 |
+
MAX_WAIT_TIME = 300
|
| 29 |
+
|
| 30 |
+
# Check interval (seconds)
|
| 31 |
+
CHECK_INTERVAL = 10
|
| 32 |
+
|
| 33 |
+
# Retry count
|
| 34 |
+
MAX_RETRIES = 3
|
| 35 |
+
|
| 36 |
+
# ============== File Configuration ==============
|
| 37 |
+
|
| 38 |
+
# Music file save directory
|
| 39 |
+
OUTPUT_DIRECTORY = "./generated_music"
|
| 40 |
+
|
| 41 |
+
# Audio format
|
| 42 |
+
AUDIO_FORMAT = "mp3" # Options: mp3, wav
|
| 43 |
+
|
| 44 |
+
# ============== Batch Generation Configuration ==============
|
| 45 |
+
|
| 46 |
+
# Concurrency for batch generation
|
| 47 |
+
BATCH_CONCURRENCY = 5
|
| 48 |
+
|
| 49 |
+
# Batch generation delay (seconds, to avoid rate limiting)
|
| 50 |
+
BATCH_DELAY = 2
|
| 51 |
+
|
| 52 |
+
# ============== Logging Configuration ==============
|
| 53 |
+
|
| 54 |
+
# Log level
|
| 55 |
+
LOG_LEVEL = "INFO" # Options: DEBUG, INFO, WARNING, ERROR
|
| 56 |
+
|
| 57 |
+
# Log file path
|
| 58 |
+
LOG_FILE = "./suno_api.log"
|
| 59 |
+
|
| 60 |
+
# Whether to output to console
|
| 61 |
+
LOG_TO_CONSOLE = True
|
| 62 |
+
|
| 63 |
+
# ============== Webhook Configuration ==============
|
| 64 |
+
|
| 65 |
+
# Webhook callback URL (optional)
|
| 66 |
+
WEBHOOK_URL = None # Example: "https://your-domain.com/webhook"
|
| 67 |
+
|
| 68 |
+
# Webhook secret (for verifying callback requests)
|
| 69 |
+
WEBHOOK_SECRET = None
|
| 70 |
+
|
baseline_generate/suno/suno_4_5.py
ADDED
|
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Suno API Batch Generation - V4.5 Special Edition
|
| 4 |
+
Supported models: V4_5 (default), V4_5PLUS, V4_5ALL
|
| 5 |
+
"""
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
import requests
|
| 9 |
+
import os
|
| 10 |
+
import logging
|
| 11 |
+
import csv
|
| 12 |
+
from requests.adapters import HTTPAdapter
|
| 13 |
+
from urllib3.util.retry import Retry
|
| 14 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 15 |
+
from collections import deque
|
| 16 |
+
from threading import Lock, Semaphore
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 20 |
+
from config import SUNO_API_KEY
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Configure logging
|
| 24 |
+
def setup_logging(output_dir):
    """Configure the 'SunoBatchV4_5' logger with a file and console handler.

    Creates a timestamped log file inside *output_dir* and returns the
    configured logger together with that file's path.
    """
    stamp = time.strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(output_dir, f"run_log_v4_5_{stamp}.txt")

    batch_logger = logging.getLogger('SunoBatchV4_5')
    batch_logger.setLevel(logging.INFO)

    # Drop handlers left over from a previous call so lines are not duplicated.
    if batch_logger.hasHandlers():
        batch_logger.handlers.clear()

    # Same bare-message format for both destinations; file handler first.
    plain_format = logging.Formatter('%(message)s')
    for handler in (logging.FileHandler(log_file, encoding='utf-8'),
                    logging.StreamHandler()):
        handler.setFormatter(plain_format)
        batch_logger.addHandler(handler)

    return batch_logger, log_file
|
| 46 |
+
|
| 47 |
+
# Global logger
# Module-wide shared logger; handlers are attached by setup_logging().
logger = logging.getLogger('SunoBatchV4_5')
|
| 49 |
+
|
| 50 |
+
# Replace print with logger.info
def print_log(msg):
    """Route *msg* through the shared module logger instead of print()."""
    logger.info(msg)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class SunoAPI:
    """Simplified Suno API client.

    Wraps the sunoapi.org v1 endpoints used by the batch script
    (generation submit, status polling, timestamped lyrics, file
    download) over a session with automatic retries on transient
    server errors.
    """

    def __init__(self, api_key):
        self.api_key = api_key
        self.base_url = 'https://api.sunoapi.org/api/v1'
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

        # Configure retry strategy for transient 5xx failures.
        self.session = requests.Session()
        retry_strategy = Retry(
            total=5,  # Maximum retry count
            backoff_factor=1,  # Retry interval (1s, 2s, 4s, 8s...)
            status_forcelist=[500, 502, 503, 504],  # Status codes that need retry
            allowed_methods=["HEAD", "GET", "POST", "OPTIONS"]  # Allowed retry methods
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    def generate_music(self, prompt, model='V4_5', vocalGender=None, **options):
        """Submit a generation request and return the task id.

        Raises Exception on transport errors, non-JSON replies, or an
        API-level code other than 200.
        """
        payload = {
            'prompt': prompt,
            'model': model,
            'callBackUrl': 'https://example.com/callback',
            **options
        }

        if vocalGender:
            payload['vocalGender'] = vocalGender

        try:
            response = self.session.post(
                f'{self.base_url}/generate',
                headers=self.headers,
                json=payload,
                timeout=30
            )

            # Check HTTP errors
            response.raise_for_status()

            # Try to parse JSON
            try:
                result = response.json()
            except json.JSONDecodeError:
                raise Exception(f"API returned non-JSON response: {response.text[:200]}")

            if result.get('code') != 200:
                raise Exception(f"Generation failed: {result.get('msg', result)}")

            return result['data']['taskId']

        except requests.exceptions.RequestException as e:
            raise Exception(f"Request exception: {str(e)}")

    def get_task_status(self, task_id):
        """Fetch the record-info payload for *task_id*.

        Network/HTTP errors propagate to the caller; wait_for_completion
        treats them as transient and retries.
        """
        response = self.session.get(
            f'{self.base_url}/generate/record-info?taskId={task_id}',
            headers={'Authorization': f'Bearer {self.api_key}'},
            timeout=30
        )
        response.raise_for_status()
        return response.json().get('data', {})

    def get_timestamped_lyrics(self, task_id, audio_id):
        """Fetch timestamped lyrics for one generated track.

        Lyrics retrieval is a non-fatal extra: any failure yields {}.
        """
        payload = {
            'taskId': task_id,
            'audioId': audio_id
        }

        try:
            response = self.session.post(
                f'{self.base_url}/generate/get-timestamped-lyrics',
                headers=self.headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            return response.json()
        except Exception:
            return {}  # Lyrics retrieval failure is non-fatal error

    def wait_for_completion(self, task_id, max_wait_time=600, check_interval=5):
        """Poll until the task succeeds; return result and polling statistics.

        Raises immediately when the task reports FAILED. (BUG FIX: the
        original wrapped the FAILED check in a broad ``except`` that
        swallowed its own raised exception and kept polling, so failed
        tasks were only reported as 'Task timeout' after max_wait_time.)
        Transient status-query errors are retried until the deadline.
        """
        start_time = time.time()
        poll_count = 0
        total_poll_time = 0

        while time.time() - start_time < max_wait_time:
            try:
                poll_start = time.time()
                status = self.get_task_status(task_id)
                total_poll_time += time.time() - poll_start
                poll_count += 1
            except Exception:
                # Transient query failure: retry unless we are out of time.
                if time.time() - start_time >= max_wait_time:
                    raise
                time.sleep(check_interval)
                continue

            current_status = status.get('status')

            if current_status == 'SUCCESS':
                return {
                    'result': status.get('response'),
                    'wait_time': time.time() - start_time,
                    'poll_count': poll_count,
                    'avg_poll_time': total_poll_time / poll_count if poll_count > 0 else 0
                }
            if current_status == 'FAILED':
                # Raised outside any try so the failure surfaces right away.
                raise Exception(f"Task failed: {status.get('errorMessage')}")

            time.sleep(check_interval)

        raise Exception('Task timeout')

    def download_file(self, url, save_path):
        """Stream *url* to *save_path*; return download statistics.

        Returns {'success': True, 'bytes', 'time', 'speed'} on success,
        or {'success': False, 'error'} — download errors are non-fatal.
        """
        try:
            start_time = time.time()
            downloaded_bytes = 0

            # Use session to download in 8 KiB chunks
            with self.session.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(save_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
                        downloaded_bytes += len(chunk)

            download_time = time.time() - start_time
            return {
                'success': True,
                'bytes': downloaded_bytes,
                'time': download_time,
                'speed': downloaded_bytes / download_time if download_time > 0 else 0
            }
        except Exception as e:
            print_log(f"Download failed {url}: {e}")
            return {'success': False, 'error': str(e)}
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Result record lock
|
| 210 |
+
result_lock = Lock()
|
| 211 |
+
|
| 212 |
+
def save_result_record(output_dir, record):
|
| 213 |
+
"""Save single result to CSV in real-time"""
|
| 214 |
+
file_path = os.path.join(output_dir, "generation_results.csv")
|
| 215 |
+
file_exists = os.path.isfile(file_path)
|
| 216 |
+
|
| 217 |
+
# Only record key information
|
| 218 |
+
row = {
|
| 219 |
+
'song_id': record.get('song_id'),
|
| 220 |
+
'task_id': record.get('task_id'),
|
| 221 |
+
'status': 'SUCCESS' if record.get('success') else 'FAILED',
|
| 222 |
+
'error': record.get('error', ''),
|
| 223 |
+
'submit_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(record.get('submit_time', 0))),
|
| 224 |
+
'total_time': f"{record.get('total_time', 0):.1f}",
|
| 225 |
+
'tracks_count': record.get('tracks_count', 0)
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
with result_lock:
|
| 229 |
+
with open(file_path, 'a', newline='', encoding='utf-8') as f:
|
| 230 |
+
writer = csv.DictWriter(f, fieldnames=['song_id', 'task_id', 'status', 'error', 'submit_time', 'total_time', 'tracks_count'])
|
| 231 |
+
if not file_exists:
|
| 232 |
+
writer.writeheader()
|
| 233 |
+
writer.writerow(row)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ImprovedRateLimiter:
|
| 237 |
+
"""Improved rate limiter (with statistics)
|
| 238 |
+
|
| 239 |
+
Precise control: maximum 8 requests per 10 seconds
|
| 240 |
+
Uses sliding window algorithm to ensure no more than 8 requests in any 10-second time window
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
def __init__(self, max_requests=5, time_window=10):
|
| 244 |
+
self.max_requests = max_requests
|
| 245 |
+
self.time_window = time_window
|
| 246 |
+
self.request_times = deque()
|
| 247 |
+
self.lock = Lock()
|
| 248 |
+
self.semaphore = Semaphore(max_requests)
|
| 249 |
+
|
| 250 |
+
# Statistics
|
| 251 |
+
self.total_wait_time = 0
|
| 252 |
+
self.wait_count = 0
|
| 253 |
+
self.total_requests = 0
|
| 254 |
+
|
| 255 |
+
def acquire(self):
|
| 256 |
+
"""Acquire request permission"""
|
| 257 |
+
with self.lock:
|
| 258 |
+
now = time.time()
|
| 259 |
+
|
| 260 |
+
# Clean expired request records
|
| 261 |
+
while self.request_times and now - self.request_times[0] >= self.time_window:
|
| 262 |
+
self.request_times.popleft()
|
| 263 |
+
|
| 264 |
+
# If limit reached, calculate wait time needed
|
| 265 |
+
wait_time = 0
|
| 266 |
+
if len(self.request_times) >= self.max_requests:
|
| 267 |
+
oldest_request = self.request_times[0]
|
| 268 |
+
wait_time = self.time_window - (now - oldest_request) + 0.05 # Add buffer
|
| 269 |
+
|
| 270 |
+
if wait_time > 0:
|
| 271 |
+
print_log(f" [Rate Limit] Waiting {wait_time:.2f} seconds...")
|
| 272 |
+
time.sleep(wait_time)
|
| 273 |
+
|
| 274 |
+
# Record wait time
|
| 275 |
+
self.total_wait_time += wait_time
|
| 276 |
+
self.wait_count += 1
|
| 277 |
+
|
| 278 |
+
# Re-clean
|
| 279 |
+
now = time.time()
|
| 280 |
+
while self.request_times and now - self.request_times[0] >= self.time_window:
|
| 281 |
+
self.request_times.popleft()
|
| 282 |
+
|
| 283 |
+
# Record this request time
|
| 284 |
+
self.request_times.append(time.time())
|
| 285 |
+
self.total_requests += 1
|
| 286 |
+
|
| 287 |
+
def get_current_rate(self):
|
| 288 |
+
"""Get current rate (number of requests in last 10 seconds)"""
|
| 289 |
+
with self.lock:
|
| 290 |
+
now = time.time()
|
| 291 |
+
while self.request_times and now - self.request_times[0] >= self.time_window:
|
| 292 |
+
self.request_times.popleft()
|
| 293 |
+
return len(self.request_times)
|
| 294 |
+
|
| 295 |
+
def get_stats(self):
|
| 296 |
+
"""Get statistics"""
|
| 297 |
+
with self.lock:
|
| 298 |
+
return {
|
| 299 |
+
'total_requests': self.total_requests,
|
| 300 |
+
'total_wait_time': self.total_wait_time,
|
| 301 |
+
'wait_count': self.wait_count,
|
| 302 |
+
'avg_wait_time': self.total_wait_time / self.wait_count if self.wait_count > 0 else 0
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# Global rate limiter (5 requests per 10 seconds)
|
| 307 |
+
rate_limiter = ImprovedRateLimiter(max_requests=5, time_window=10)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def submit_generation_task(api, song_index, data):
|
| 311 |
+
"""Phase 1: Submit generation task (rate limited)"""
|
| 312 |
+
# Use sunov4_5_000001 format
|
| 313 |
+
song_id = data.get("id", f"sunov4_5_{song_index:06d}")
|
| 314 |
+
|
| 315 |
+
try:
|
| 316 |
+
description = data.get("description", "")
|
| 317 |
+
lyrics = data.get("lyrics", "")
|
| 318 |
+
vocal_gender = data.get("vocalGender")
|
| 319 |
+
|
| 320 |
+
print_log(f"[Song {song_id}] Submitting task... (current rate: {rate_limiter.get_current_rate()}/5)")
|
| 321 |
+
|
| 322 |
+
# Record request start time
|
| 323 |
+
request_start = time.time()
|
| 324 |
+
|
| 325 |
+
# Rate limiting
|
| 326 |
+
rate_limiter.acquire()
|
| 327 |
+
|
| 328 |
+
# Submit task
|
| 329 |
+
submit_start = time.time()
|
| 330 |
+
task_id = api.generate_music(
|
| 331 |
+
prompt=lyrics,
|
| 332 |
+
style=description,
|
| 333 |
+
title=f"Song_{song_id}",
|
| 334 |
+
model='V4_5', # Explicitly specify V4.5 model
|
| 335 |
+
customMode=True,
|
| 336 |
+
instrumental=False,
|
| 337 |
+
vocalGender=vocal_gender
|
| 338 |
+
)
|
| 339 |
+
request_time = time.time() - submit_start
|
| 340 |
+
|
| 341 |
+
print_log(f"[Song {song_id}] ✓ Task submitted, ID: {task_id}")
|
| 342 |
+
|
| 343 |
+
return {
|
| 344 |
+
'song_id': song_id,
|
| 345 |
+
'song_index': song_index,
|
| 346 |
+
'task_id': task_id,
|
| 347 |
+
'data': data,
|
| 348 |
+
'submit_time': time.time(),
|
| 349 |
+
'request_time': request_time,
|
| 350 |
+
'success': True
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
print_log(f"[Song {song_id}] ✗ Submission failed: {e}")
|
| 355 |
+
# If submission fails, also record it (even though not at download stage yet)
|
| 356 |
+
return {
|
| 357 |
+
'song_id': song_id,
|
| 358 |
+
'song_index': song_index,
|
| 359 |
+
'success': False,
|
| 360 |
+
'error': str(e)
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def wait_and_download_result(api, task_info, output_dir):
|
| 365 |
+
"""Phase 2: Wait for result and download (not rate limited)"""
|
| 366 |
+
if not task_info['success']:
|
| 367 |
+
return task_info
|
| 368 |
+
|
| 369 |
+
song_id = task_info['song_id']
|
| 370 |
+
song_index = task_info['song_index']
|
| 371 |
+
task_id = task_info['task_id']
|
| 372 |
+
data = task_info['data']
|
| 373 |
+
start_time = task_info['submit_time']
|
| 374 |
+
|
| 375 |
+
try:
|
| 376 |
+
original_lyrics = data.get("original_lyrics", data.get("lyrics", ""))
|
| 377 |
+
lyrics = data.get("lyrics", "")
|
| 378 |
+
description = data.get("description", "")
|
| 379 |
+
|
| 380 |
+
print_log(f"[Song {song_id}] Waiting for generation to complete...")
|
| 381 |
+
|
| 382 |
+
# Wait for completion (returns detailed statistics)
|
| 383 |
+
wait_result = api.wait_for_completion(task_id, max_wait_time=600, check_interval=8)
|
| 384 |
+
result = wait_result['result']
|
| 385 |
+
|
| 386 |
+
# Process returned result
|
| 387 |
+
tracks = []
|
| 388 |
+
if isinstance(result, dict):
|
| 389 |
+
if 'data' in result:
|
| 390 |
+
tracks = result['data']
|
| 391 |
+
elif 'sunoData' in result:
|
| 392 |
+
tracks = result['sunoData']
|
| 393 |
+
else:
|
| 394 |
+
for key, value in result.items():
|
| 395 |
+
if isinstance(value, list) and len(value) > 0 and 'audioUrl' in value[0]:
|
| 396 |
+
tracks = value
|
| 397 |
+
break
|
| 398 |
+
|
| 399 |
+
if not tracks:
|
| 400 |
+
raise Exception("Audio track data not found")
|
| 401 |
+
|
| 402 |
+
# Download phase statistics
|
| 403 |
+
download_start = time.time()
|
| 404 |
+
downloaded_files = []
|
| 405 |
+
total_download_bytes = 0
|
| 406 |
+
download_count = 0
|
| 407 |
+
|
| 408 |
+
# Process each track
|
| 409 |
+
for track_idx, track in enumerate(tracks):
|
| 410 |
+
audio_url = track.get('audioUrl') or track.get('audio_url')
|
| 411 |
+
audio_id = track.get('id')
|
| 412 |
+
|
| 413 |
+
base_filename = f"{song_id}_{track_idx}"
|
| 414 |
+
audio_path = os.path.join(output_dir, f"{base_filename}.mp3")
|
| 415 |
+
lyrics_path = os.path.join(output_dir, f"{base_filename}_lyrics.json")
|
| 416 |
+
|
| 417 |
+
# Download audio
|
| 418 |
+
if audio_url:
|
| 419 |
+
download_result = api.download_file(audio_url, audio_path)
|
| 420 |
+
if download_result['success']:
|
| 421 |
+
downloaded_files.append(audio_path)
|
| 422 |
+
total_download_bytes += download_result['bytes']
|
| 423 |
+
download_count += 1
|
| 424 |
+
|
| 425 |
+
# Get timestamped lyrics
|
| 426 |
+
timestamped_lyrics_data = None
|
| 427 |
+
if audio_id:
|
| 428 |
+
try:
|
| 429 |
+
lyrics_response = api.get_timestamped_lyrics(task_id, audio_id)
|
| 430 |
+
if lyrics_response.get('code') == 200:
|
| 431 |
+
timestamped_lyrics_data = lyrics_response.get('data')
|
| 432 |
+
except Exception as e:
|
| 433 |
+
print_log(f"[Song {song_id}] Track {track_idx+1}: Failed to get lyrics: {e}")
|
| 434 |
+
|
| 435 |
+
# Save lyrics and metadata
|
| 436 |
+
lyrics_content = {
|
| 437 |
+
"song_id": song_id,
|
| 438 |
+
"song_index": song_index,
|
| 439 |
+
"track_index": track_idx,
|
| 440 |
+
"original_lyrics": original_lyrics,
|
| 441 |
+
"cleaned_lyrics": lyrics,
|
| 442 |
+
"timestamped_lyrics": timestamped_lyrics_data,
|
| 443 |
+
"style": description,
|
| 444 |
+
"full_track_data": track
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
with open(lyrics_path, 'w', encoding='utf-8') as f:
|
| 448 |
+
json.dump(lyrics_content, f, ensure_ascii=False, indent=2)
|
| 449 |
+
downloaded_files.append(lyrics_path)
|
| 450 |
+
|
| 451 |
+
download_time = time.time() - download_start
|
| 452 |
+
total_time = time.time() - start_time
|
| 453 |
+
|
| 454 |
+
print_log(f"[Song {song_id}] ✓ Complete! {len(tracks)} tracks, took {total_time:.1f} seconds")
|
| 455 |
+
|
| 456 |
+
final_result = {
|
| 457 |
+
'song_id': song_id,
|
| 458 |
+
'song_index': song_index,
|
| 459 |
+
'task_id': task_id,
|
| 460 |
+
'success': True,
|
| 461 |
+
'tracks_count': len(tracks),
|
| 462 |
+
'files': downloaded_files,
|
| 463 |
+
'total_time': total_time,
|
| 464 |
+
'submit_time': start_time,
|
| 465 |
+
'wait_time': wait_result['wait_time'],
|
| 466 |
+
'poll_count': wait_result['poll_count'],
|
| 467 |
+
'avg_poll_time': wait_result['avg_poll_time'],
|
| 468 |
+
'download_time': download_time,
|
| 469 |
+
'download_bytes': total_download_bytes,
|
| 470 |
+
'download_count': download_count,
|
| 471 |
+
'avg_download_speed': total_download_bytes / download_time if download_time > 0 else 0
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
# Save result in real-time
|
| 475 |
+
save_result_record(output_dir, final_result)
|
| 476 |
+
return final_result
|
| 477 |
+
|
| 478 |
+
except Exception as e:
|
| 479 |
+
total_time = time.time() - start_time
|
| 480 |
+
print_log(f"[Song {song_id}] ✗ Processing failed: {e} (took {total_time:.1f} seconds)")
|
| 481 |
+
|
| 482 |
+
error_result = {
|
| 483 |
+
'song_id': song_id,
|
| 484 |
+
'song_index': song_index,
|
| 485 |
+
'task_id': task_id,
|
| 486 |
+
'success': False,
|
| 487 |
+
'error': str(e),
|
| 488 |
+
'total_time': total_time,
|
| 489 |
+
'submit_time': start_time
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
# Save result in real-time
|
| 493 |
+
save_result_record(output_dir, error_result)
|
| 494 |
+
return error_result
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def format_bytes(bytes_size):
|
| 498 |
+
"""Format byte size"""
|
| 499 |
+
for unit in ['B', 'KB', 'MB', 'GB']:
|
| 500 |
+
if bytes_size < 1024.0:
|
| 501 |
+
return f"{bytes_size:.2f} {unit}"
|
| 502 |
+
bytes_size /= 1024.0
|
| 503 |
+
return f"{bytes_size:.2f} TB"
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def format_speed(bytes_per_sec):
|
| 507 |
+
"""Format speed"""
|
| 508 |
+
return f"{format_bytes(bytes_per_sec)}/s"
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def main():
|
| 512 |
+
"""Main program - two-phase concurrent processing"""
|
| 513 |
+
input_file = "cleaned_data_truncated.json"
|
| 514 |
+
output_dir = "sunov4_5_truncated"
|
| 515 |
+
# Create output directory
|
| 516 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 517 |
+
|
| 518 |
+
# Initialize logging
|
| 519 |
+
global logger
|
| 520 |
+
logger, log_file = setup_logging(output_dir)
|
| 521 |
+
|
| 522 |
+
print_log("=" * 70)
|
| 523 |
+
print_log("Suno API Batch Generation - V4.5 Special Edition")
|
| 524 |
+
print_log("Strategy: Fast submission (5 requests/10s) + Parallel waiting + Detailed performance analysis")
|
| 525 |
+
print_log(f"Log file: {log_file}")
|
| 526 |
+
print_log("=" * 70)
|
| 527 |
+
|
| 528 |
+
# Read input file
|
| 529 |
+
try:
|
| 530 |
+
all_data = []
|
| 531 |
+
if input_file.endswith('.jsonl'):
|
| 532 |
+
try:
|
| 533 |
+
with open(input_file, 'r', encoding='utf-8') as f:
|
| 534 |
+
# Try reading first line to determine format
|
| 535 |
+
first_line = f.readline().strip()
|
| 536 |
+
if first_line.startswith('['):
|
| 537 |
+
# Looks like regular JSON array
|
| 538 |
+
f.seek(0)
|
| 539 |
+
all_data = json.load(f)
|
| 540 |
+
else:
|
| 541 |
+
# Try reading line by line
|
| 542 |
+
f.seek(0)
|
| 543 |
+
for i, line in enumerate(f):
|
| 544 |
+
line = line.strip()
|
| 545 |
+
if line:
|
| 546 |
+
all_data.append(json.loads(line))
|
| 547 |
+
except json.JSONDecodeError:
|
| 548 |
+
# If above parsing fails, try one final read as regular JSON
|
| 549 |
+
print_log(f"Note: Failed to parse {input_file} as JSONL format, trying as regular JSON...")
|
| 550 |
+
with open(input_file, 'r', encoding='utf-8') as f:
|
| 551 |
+
all_data = json.load(f)
|
| 552 |
+
else:
|
| 553 |
+
with open(input_file, 'r', encoding='utf-8') as f:
|
| 554 |
+
all_data = json.load(f)
|
| 555 |
+
|
| 556 |
+
except FileNotFoundError:
|
| 557 |
+
print_log(f"File {input_file} not found.")
|
| 558 |
+
return
|
| 559 |
+
except json.JSONDecodeError as e:
|
| 560 |
+
print_log(f"JSON parsing error: {e}")
|
| 561 |
+
return
|
| 562 |
+
|
| 563 |
+
# Initialize API
|
| 564 |
+
api = SunoAPI(SUNO_API_KEY)
|
| 565 |
+
|
| 566 |
+
print_log(f"\nPreparing to generate {len(all_data)} songs...")
|
| 567 |
+
print_log(f"Start time: {time.strftime('%H:%M:%S')}\n")
|
| 568 |
+
|
| 569 |
+
overall_start_time = time.time()
|
| 570 |
+
|
| 571 |
+
# ===== Phase 1: Batch Submission =====
|
| 572 |
+
print_log("\n" + "=" * 70)
|
| 573 |
+
print_log("Phase 1: Batch Submission")
|
| 574 |
+
print_log("=" * 70 + "\n")
|
| 575 |
+
|
| 576 |
+
submit_start_time = time.time()
|
| 577 |
+
submitted_tasks = []
|
| 578 |
+
total_request_time = 0
|
| 579 |
+
|
| 580 |
+
# Adjust rate limit: maximum 5 requests per 10 seconds
|
| 581 |
+
rate_limiter.max_requests = 5
|
| 582 |
+
rate_limiter.time_window = 10
|
| 583 |
+
rate_limiter.request_times.clear()
|
| 584 |
+
print_log(f"Rate limit: {rate_limiter.max_requests} requests / {rate_limiter.time_window} seconds")
|
| 585 |
+
|
| 586 |
+
# Only submit tasks that need to run
|
| 587 |
+
tasks_to_run = []
|
| 588 |
+
for i, data in enumerate(all_data, 1):
|
| 589 |
+
tasks_to_run.append((i, data))
|
| 590 |
+
|
| 591 |
+
print_log(f"Number of tasks to submit: {len(tasks_to_run)}")
|
| 592 |
+
|
| 593 |
+
# Use thread pool for submission
|
| 594 |
+
# Submission concurrency is controlled by rate_limiter, can be set to 5
|
| 595 |
+
with ThreadPoolExecutor(max_workers=5) as executor:
|
| 596 |
+
submit_futures = {
|
| 597 |
+
executor.submit(submit_generation_task, api, idx, data): idx
|
| 598 |
+
for idx, data in tasks_to_run
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
with tqdm(total=len(tasks_to_run), desc="Submitting tasks", unit="song") as pbar:
|
| 602 |
+
for future in as_completed(submit_futures):
|
| 603 |
+
result = future.result()
|
| 604 |
+
submitted_tasks.append(result)
|
| 605 |
+
if result.get('success') and 'request_time' in result:
|
| 606 |
+
total_request_time += result['request_time']
|
| 607 |
+
pbar.update(1)
|
| 608 |
+
|
| 609 |
+
submit_phase_time = time.time() - submit_start_time
|
| 610 |
+
success_submits = sum(1 for t in submitted_tasks if t['success'])
|
| 611 |
+
|
| 612 |
+
# Get rate limit statistics
|
| 613 |
+
rate_limit_stats = rate_limiter.get_stats()
|
| 614 |
+
|
| 615 |
+
print_log(f"\nSubmission phase complete: {success_submits}/{len(tasks_to_run)} successful")
|
| 616 |
+
print_log(f" Total time: {submit_phase_time:.1f} seconds")
|
| 617 |
+
print_log(f" Actual request time: {total_request_time:.2f} seconds")
|
| 618 |
+
print_log(f" Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds ({rate_limit_stats['wait_count']} times)")
|
| 619 |
+
if rate_limit_stats['wait_count'] > 0:
|
| 620 |
+
print_log(f" Average wait time: {rate_limit_stats['avg_wait_time']:.2f} seconds/time")
|
| 621 |
+
|
| 622 |
+
# ===== Phase 2: Parallel Waiting and Download =====
|
| 623 |
+
print_log("\n" + "=" * 70)
|
| 624 |
+
print_log("Phase 2: Wait for Generation and Download")
|
| 625 |
+
print_log("=" * 70 + "\n")
|
| 626 |
+
|
| 627 |
+
wait_start_time = time.time()
|
| 628 |
+
final_results = []
|
| 629 |
+
|
| 630 |
+
# Use more threads for parallel waiting (not rate limited)
|
| 631 |
+
with ThreadPoolExecutor(max_workers=20) as executor:
|
| 632 |
+
download_futures = {
|
| 633 |
+
executor.submit(wait_and_download_result, api, task, output_dir): task
|
| 634 |
+
for task in submitted_tasks if task['success']
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
# Add failed submission tasks to results
|
| 638 |
+
for task in submitted_tasks:
|
| 639 |
+
if not task['success']:
|
| 640 |
+
final_results.append(task)
|
| 641 |
+
|
| 642 |
+
with tqdm(total=len(download_futures), desc="Downloading results", unit="song") as pbar:
|
| 643 |
+
for future in as_completed(download_futures):
|
| 644 |
+
result = future.result()
|
| 645 |
+
final_results.append(result)
|
| 646 |
+
pbar.update(1)
|
| 647 |
+
|
| 648 |
+
wait_phase_time = time.time() - wait_start_time
|
| 649 |
+
|
| 650 |
+
# ===== Detailed Statistics and Report =====
|
| 651 |
+
overall_time = time.time() - overall_start_time
|
| 652 |
+
|
| 653 |
+
print_log("\n" + "=" * 70)
|
| 654 |
+
print_log("Batch Generation Complete - Detailed Performance Report")
|
| 655 |
+
print_log("=" * 70)
|
| 656 |
+
|
| 657 |
+
success_count = sum(1 for r in final_results if r.get('success'))
|
| 658 |
+
fail_count = len(final_results) - success_count
|
| 659 |
+
total_tracks = sum(r.get('tracks_count', 0) for r in final_results if r.get('success'))
|
| 660 |
+
|
| 661 |
+
successful_results = [r for r in final_results if r.get('success')]
|
| 662 |
+
|
| 663 |
+
# Basic Statistics
|
| 664 |
+
print_log(f"\n[Basic Statistics]")
|
| 665 |
+
print_log(f" Total songs: {len(all_data)}")
|
| 666 |
+
print_log(f" Successful: {success_count}")
|
| 667 |
+
print_log(f" Failed: {fail_count}")
|
| 668 |
+
print_log(f" Total tracks: {total_tracks}")
|
| 669 |
+
if success_count > 0:
|
| 670 |
+
avg_tracks = total_tracks / success_count
|
| 671 |
+
print_log(f" Average tracks per song: {avg_tracks:.2f}")
|
| 672 |
+
|
| 673 |
+
# Time Statistics
|
| 674 |
+
print_log(f"\n[Time Statistics]")
|
| 675 |
+
print_log(f" ├── Submission phase: {submit_phase_time:.1f} seconds")
|
| 676 |
+
print_log(f" │ ├── Actual request time: {total_request_time:.2f} seconds")
|
| 677 |
+
print_log(f" │ └── Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds")
|
| 678 |
+
print_log(f" ├── Generation waiting phase: {wait_phase_time:.1f} seconds")
|
| 679 |
+
|
| 680 |
+
if successful_results:
|
| 681 |
+
wait_times = [r.get('wait_time', 0) for r in successful_results if 'wait_time' in r]
|
| 682 |
+
download_times = [r.get('download_time', 0) for r in successful_results if 'download_time' in r]
|
| 683 |
+
|
| 684 |
+
if wait_times:
|
| 685 |
+
avg_wait = sum(wait_times) / len(wait_times)
|
| 686 |
+
min_wait = min(wait_times)
|
| 687 |
+
max_wait = max(wait_times)
|
| 688 |
+
print_log(f" │ ├── Average wait time: {avg_wait:.1f} seconds/song")
|
| 689 |
+
print_log(f" │ ├── Fastest: {min_wait:.1f} seconds")
|
| 690 |
+
print_log(f" │ └── Slowest: {max_wait:.1f} seconds")
|
| 691 |
+
|
| 692 |
+
if download_times:
|
| 693 |
+
total_download_time = sum(download_times)
|
| 694 |
+
avg_download = total_download_time / len(download_times)
|
| 695 |
+
print_log(f" ├── Download phase: {total_download_time:.1f} seconds")
|
| 696 |
+
print_log(f" │ └── Average download time: {avg_download:.2f} seconds/song")
|
| 697 |
+
|
| 698 |
+
print_log(f" └── Total time: {overall_time:.1f} seconds ({overall_time/60:.1f} minutes)")
|
| 699 |
+
|
| 700 |
+
# Single Song Generation Statistics
|
| 701 |
+
if successful_results:
|
| 702 |
+
total_times = [r.get('total_time', 0) for r in successful_results if 'total_time' in r]
|
| 703 |
+
if total_times:
|
| 704 |
+
print_log(f"\n[Single Song Generation Statistics]")
|
| 705 |
+
avg_time = sum(total_times) / len(total_times)
|
| 706 |
+
min_time = min(total_times)
|
| 707 |
+
max_time = max(total_times)
|
| 708 |
+
print_log(f" Average total time per song: {avg_time:.1f} seconds")
|
| 709 |
+
print_log(f" Fastest generation: {min_time:.1f} seconds")
|
| 710 |
+
print_log(f" Slowest generation: {max_time:.1f} seconds")
|
| 711 |
+
|
| 712 |
+
# Download Statistics
|
| 713 |
+
total_download_bytes = sum(r.get('download_bytes', 0) for r in successful_results)
|
| 714 |
+
total_download_count = sum(r.get('download_count', 0) for r in successful_results)
|
| 715 |
+
|
| 716 |
+
if total_download_bytes > 0:
|
| 717 |
+
print_log(f"\n[Download Statistics]")
|
| 718 |
+
print_log(f" Total download: {format_bytes(total_download_bytes)}")
|
| 719 |
+
print_log(f" Number of files: {total_download_count}")
|
| 720 |
+
print_log(f" Average file size: {format_bytes(total_download_bytes / total_download_count)}")
|
| 721 |
+
|
| 722 |
+
download_speeds = [r.get('avg_download_speed', 0) for r in successful_results if r.get('avg_download_speed', 0) > 0]
|
| 723 |
+
if download_speeds:
|
| 724 |
+
avg_speed = sum(download_speeds) / len(download_speeds)
|
| 725 |
+
print_log(f" Average download speed: {format_speed(avg_speed)}")
|
| 726 |
+
|
| 727 |
+
# Polling Statistics
|
| 728 |
+
poll_counts = [r.get('poll_count', 0) for r in successful_results if 'poll_count' in r]
|
| 729 |
+
if poll_counts:
|
| 730 |
+
total_polls = sum(poll_counts)
|
| 731 |
+
avg_polls = total_polls / len(poll_counts)
|
| 732 |
+
print_log(f"\n[Polling Statistics]")
|
| 733 |
+
print_log(f" Total polling count: {total_polls}")
|
| 734 |
+
print_log(f" Average polling per song: {avg_polls:.1f}")
|
| 735 |
+
|
| 736 |
+
# Efficiency Analysis
|
| 737 |
+
print_log(f"\n[Efficiency Analysis]")
|
| 738 |
+
if success_count > 0:
|
| 739 |
+
throughput = success_count / (overall_time / 60)
|
| 740 |
+
print_log(f" Actual throughput: {throughput:.2f} songs/minute")
|
| 741 |
+
|
| 742 |
+
# Theoretical fastest time (assuming no rate limit)
|
| 743 |
+
if wait_times:
|
| 744 |
+
ideal_time = submit_phase_time - rate_limit_stats['total_wait_time'] + max(wait_times)
|
| 745 |
+
efficiency = (ideal_time / overall_time) * 100
|
| 746 |
+
print_log(f" Theoretical fastest time: {ideal_time:.1f} seconds")
|
| 747 |
+
print_log(f" Concurrency efficiency: {efficiency:.1f}%")
|
| 748 |
+
|
| 749 |
+
# Show failed songs
|
| 750 |
+
if fail_count > 0:
|
| 751 |
+
print_log("\n" + "=" * 70)
|
| 752 |
+
print_log("Failed Songs List")
|
| 753 |
+
print_log("=" * 70)
|
| 754 |
+
for r in sorted(final_results, key=lambda x: x.get('song_index', 0)):
|
| 755 |
+
if not r.get('success'):
|
| 756 |
+
song_id = r.get('song_id', r.get('song_index', 'Unknown'))
|
| 757 |
+
print_log(f" [{song_id}] {r.get('error', 'Unknown error')}")
|
| 758 |
+
|
| 759 |
+
print_log("\n" + "=" * 70)
|
| 760 |
+
print_log(f"All files saved to: {os.path.abspath(output_dir)}")
|
| 761 |
+
print_log("=" * 70)
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
if __name__ == '__main__':
|
| 765 |
+
main()
|
| 766 |
+
|
baseline_generate/suno/suno_5.py
ADDED
|
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Suno API Batch Generation - V5 Version (5 requests per 10 seconds)
|
| 4 |
+
Changes:
|
| 5 |
+
1. Rate control: 5 requests within 10 seconds
|
| 6 |
+
2. ID format: sunov5_000001
|
| 7 |
+
"""
|
| 8 |
+
import json
|
| 9 |
+
import time
|
| 10 |
+
import requests
|
| 11 |
+
import os
|
| 12 |
+
import logging
|
| 13 |
+
import csv
|
| 14 |
+
from requests.adapters import HTTPAdapter
|
| 15 |
+
from urllib3.util.retry import Retry
|
| 16 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 17 |
+
from collections import deque
|
| 18 |
+
from threading import Lock, Semaphore
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import sys
|
| 21 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 22 |
+
from config import SUNO_API_KEY
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Configure logging
|
| 26 |
+
def setup_logging(output_dir):
|
| 27 |
+
log_file = os.path.join(output_dir, f"run_log_{time.strftime('%Y%m%d_%H%M%S')}.txt")
|
| 28 |
+
|
| 29 |
+
# Create logger
|
| 30 |
+
logger = logging.getLogger('SunoBatchV5')
|
| 31 |
+
logger.setLevel(logging.INFO)
|
| 32 |
+
|
| 33 |
+
# Clear old handlers
|
| 34 |
+
if logger.hasHandlers():
|
| 35 |
+
logger.handlers.clear()
|
| 36 |
+
|
| 37 |
+
# File Handler
|
| 38 |
+
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
| 39 |
+
file_handler.setFormatter(logging.Formatter('%(message)s'))
|
| 40 |
+
logger.addHandler(file_handler)
|
| 41 |
+
|
| 42 |
+
# Console Handler
|
| 43 |
+
console_handler = logging.StreamHandler()
|
| 44 |
+
console_handler.setFormatter(logging.Formatter('%(message)s'))
|
| 45 |
+
logger.addHandler(console_handler)
|
| 46 |
+
|
| 47 |
+
return logger, log_file
|
| 48 |
+
|
| 49 |
+
# Global logger
|
| 50 |
+
logger = logging.getLogger('SunoBatchV5')
|
| 51 |
+
|
| 52 |
+
# Replace print with logger.info
|
| 53 |
+
def print_log(msg):
|
| 54 |
+
logger.info(msg)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SunoAPI:
    """Simplified Suno API client.

    Wraps the sunoapi.org HTTP endpoints with a retrying ``requests.Session``.
    All failures are surfaced as plain ``Exception`` so callers can handle any
    error uniformly.
    """

    def __init__(self, api_key):
        """Store credentials and build a session with automatic retries.

        Args:
            api_key: Bearer token for the Suno API.
        """
        self.api_key = api_key
        self.base_url = 'https://api.sunoapi.org/api/v1'
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

        # Retry transient server errors with exponential backoff (1s, 2s, 4s, 8s...).
        self.session = requests.Session()
        retry_strategy = Retry(
            total=5,  # Maximum retry count
            backoff_factor=1,
            status_forcelist=[500, 502, 503, 504],  # Status codes that need retry
            allowed_methods=["HEAD", "GET", "POST", "OPTIONS"]
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    def generate_music(self, prompt, model='V5', vocalGender=None, **options):
        """Submit a generation request and return the server-side task id.

        Args:
            prompt: Lyrics / prompt text.
            model: Suno model name (default 'V5').
            vocalGender: Optional vocal gender hint; omitted from the payload
                when falsy.
            **options: Extra payload fields (style, title, customMode, ...).

        Returns:
            The task id string from the API response.

        Raises:
            Exception: On HTTP errors, non-JSON responses, or an
                application-level code other than 200.
        """
        payload = {
            'prompt': prompt,
            'model': model,
            'callBackUrl': 'https://example.com/callback',
            **options
        }

        if vocalGender:
            payload['vocalGender'] = vocalGender

        try:
            response = self.session.post(
                f'{self.base_url}/generate',
                headers=self.headers,
                json=payload,
                timeout=30
            )

            # Check HTTP errors
            response.raise_for_status()

            # Try to parse JSON (the server occasionally returns HTML error pages)
            try:
                result = response.json()
            except json.JSONDecodeError:
                raise Exception(f"API returned non-JSON response: {response.text[:200]}")

            if result.get('code') != 200:
                raise Exception(f"Generation failed: {result.get('msg', result)}")

            return result['data']['taskId']

        except requests.exceptions.RequestException as e:
            raise Exception(f"Request exception: {str(e)}")

    def get_task_status(self, task_id):
        """Return the task's status payload (the response 'data' object).

        Raises on any HTTP/network failure; callers (wait_for_completion)
        decide whether a failed poll is fatal.
        """
        response = self.session.get(
            f'{self.base_url}/generate/record-info?taskId={task_id}',
            headers={'Authorization': f'Bearer {self.api_key}'},
            timeout=30
        )
        response.raise_for_status()
        return response.json().get('data', {})

    def get_timestamped_lyrics(self, task_id, audio_id):
        """Fetch timestamped lyrics for one generated track.

        Returns the raw JSON response, or an empty dict on any failure —
        lyrics retrieval is deliberately best-effort (non-fatal).
        """
        payload = {
            'taskId': task_id,
            'audioId': audio_id
        }

        try:
            response = self.session.post(
                f'{self.base_url}/generate/get-timestamped-lyrics',
                headers=self.headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            return response.json()
        except Exception:
            return {}  # Lyrics retrieval failure is a non-fatal error

    def wait_for_completion(self, task_id, max_wait_time=600, check_interval=5):
        """Poll until the task reaches SUCCESS or FAILED, or the deadline passes.

        Args:
            task_id: Task to poll.
            max_wait_time: Overall deadline in seconds.
            check_interval: Seconds between polls.

        Returns:
            Dict with 'result' (the task's 'response' payload), 'wait_time',
            'poll_count' and 'avg_poll_time'.

        Raises:
            Exception: When the task reports FAILED, or on timeout.

        Bug fix: a FAILED status now aborts immediately. Previously the
        FAILED exception was raised inside the same try block whose except
        clause swallowed it and re-polled, so failed tasks kept polling until
        the full timeout elapsed.
        """
        start_time = time.time()
        poll_count = 0
        total_poll_time = 0

        while time.time() - start_time < max_wait_time:
            poll_start = time.time()
            try:
                status = self.get_task_status(task_id)
            except Exception:
                # Transient status-query failure: keep polling, but give up
                # (re-raise) once the deadline has passed.
                if time.time() - start_time >= max_wait_time:
                    raise
                time.sleep(check_interval)
                continue

            poll_count += 1
            total_poll_time += time.time() - poll_start

            current_status = status.get('status')

            if current_status == 'SUCCESS':
                return {
                    'result': status.get('response'),
                    'wait_time': time.time() - start_time,
                    'poll_count': poll_count,
                    'avg_poll_time': total_poll_time / poll_count if poll_count > 0 else 0
                }
            if current_status == 'FAILED':
                # Terminal failure — do not retry.
                raise Exception(f"Task failed: {status.get('errorMessage')}")

            time.sleep(check_interval)

        raise Exception('Task timeout')

    def download_file(self, url, save_path):
        """Download *url* to *save_path* in 8 KiB chunks.

        Returns a stats dict: {'success': True, 'bytes', 'time', 'speed'} on
        success, or {'success': False, 'error'} on failure (never raises).
        """
        try:
            start_time = time.time()
            downloaded_bytes = 0

            # Stream through the retrying session so transient errors retry too.
            with self.session.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(save_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
                        downloaded_bytes += len(chunk)

            download_time = time.time() - start_time
            return {
                'success': True,
                'bytes': downloaded_bytes,
                'time': download_time,
                'speed': downloaded_bytes / download_time if download_time > 0 else 0
            }
        except Exception as e:
            print_log(f"Download failed {url}: {e}")
            return {'success': False, 'error': str(e)}
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# Serializes appends (and the header-existence check) on the shared results CSV.
result_lock = Lock()

def save_result_record(output_dir, record):
    """Append one generation result to generation_results.csv (thread-safe).

    Args:
        output_dir: Directory containing the CSV.
        record: Result dict from the submit/download pipeline; only the key
            fields below are persisted.

    Bug fix: the file-exists check now happens *inside* the lock. Previously
    it ran before acquiring `result_lock`, so two threads racing on a fresh
    file could both decide to write the header row.
    """
    file_path = os.path.join(output_dir, "generation_results.csv")

    # Only record key information
    row = {
        'song_id': record.get('song_id'),
        'task_id': record.get('task_id'),
        'status': 'SUCCESS' if record.get('success') else 'FAILED',
        'error': record.get('error', ''),
        'submit_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(record.get('submit_time', 0))),
        'total_time': f"{record.get('total_time', 0):.1f}",
        'tracks_count': record.get('tracks_count', 0)
    }

    with result_lock:
        # Header check must be under the lock to avoid duplicate headers.
        file_exists = os.path.isfile(file_path)
        with open(file_path, 'a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=['song_id', 'task_id', 'status', 'error', 'submit_time', 'total_time', 'tracks_count'])
            if not file_exists:
                writer.writeheader()
            writer.writerow(row)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class ImprovedRateLimiter:
    """Sliding-window rate limiter with statistics.

    Ensures no more than ``max_requests`` requests in any ``time_window``-second
    window (defaults: 5 per 10 s — the previous docstring incorrectly said 8).
    Note: ``acquire()`` sleeps while holding the internal lock, so all callers
    are fully serialized during a rate-limit wait; this also stalls
    ``get_current_rate()``/``get_stats()`` for that duration.
    """

    def __init__(self, max_requests=5, time_window=10):
        self.max_requests = max_requests
        self.time_window = time_window
        # Timestamps of requests issued within the current window (oldest first).
        self.request_times = deque()
        self.lock = Lock()
        # NOTE(review): this semaphore is never acquired/released anywhere in
        # the class — appears vestigial; confirm before removing.
        self.semaphore = Semaphore(max_requests)

        # Statistics
        self.total_wait_time = 0
        self.wait_count = 0
        self.total_requests = 0

    def acquire(self):
        """Block until a request slot is available, then record the request."""
        with self.lock:
            now = time.time()

            # Drop timestamps that have aged out of the window.
            while self.request_times and now - self.request_times[0] >= self.time_window:
                self.request_times.popleft()

            # If the window is full, wait until the oldest entry expires.
            wait_time = 0
            if len(self.request_times) >= self.max_requests:
                oldest_request = self.request_times[0]
                wait_time = self.time_window - (now - oldest_request) + 0.05  # small buffer past expiry

            if wait_time > 0:
                print_log(f"  [Rate Limit] Waiting {wait_time:.2f} seconds...")
                time.sleep(wait_time)

                # Record wait time
                self.total_wait_time += wait_time
                self.wait_count += 1

                # Re-clean: entries may have expired during the sleep.
                now = time.time()
                while self.request_times and now - self.request_times[0] >= self.time_window:
                    self.request_times.popleft()

            # Record this request's timestamp.
            self.request_times.append(time.time())
            self.total_requests += 1

    def get_current_rate(self):
        """Return the number of requests issued within the current window."""
        with self.lock:
            now = time.time()
            while self.request_times and now - self.request_times[0] >= self.time_window:
                self.request_times.popleft()
            return len(self.request_times)

    def get_stats(self):
        """Return cumulative statistics: totals and average rate-limit wait."""
        with self.lock:
            return {
                'total_requests': self.total_requests,
                'total_wait_time': self.total_wait_time,
                'wait_count': self.wait_count,
                'avg_wait_time': self.total_wait_time / self.wait_count if self.wait_count > 0 else 0
            }
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# Global sliding-window rate limiter: at most 5 submissions per 10 seconds.
# main() re-applies these same settings before the submission phase.
rate_limiter = ImprovedRateLimiter(max_requests=5, time_window=10)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def submit_generation_task(api, song_index, data):
    """Phase 1: submit one generation task to the API (rate limited).

    Args:
        api: SunoAPI client.
        song_index: 1-based index of the song in the input list.
        data: Song dict; uses 'id', 'description', 'lyrics', 'vocalGender'.

    Returns:
        On success: dict with task id, submit timestamp, request_time (HTTP
        round-trip only, excluding rate-limit waiting) and success=True.
        On failure: dict with success=False and the error string — never raises.

    Cleanup: removed the unused `request_start` local from the original
    (request time was already measured from `submit_start`).
    """
    # Fall back to the sunov5_000001 naming pattern when the input has no id.
    song_id = data.get("id", f"sunov5_{song_index:06d}")

    try:
        description = data.get("description", "")
        lyrics = data.get("lyrics", "")
        vocal_gender = data.get("vocalGender")

        print_log(f"[Song {song_id}] Submitting task... (current rate: {rate_limiter.get_current_rate()}/5)")

        # Block here until the sliding-window limiter grants a slot.
        rate_limiter.acquire()

        # Submit task; time only the HTTP round trip.
        submit_start = time.time()
        task_id = api.generate_music(
            prompt=lyrics,
            style=description,
            title=f"Song_{song_id}",
            model='V5',
            customMode=True,
            instrumental=False,
            vocalGender=vocal_gender
        )
        request_time = time.time() - submit_start

        print_log(f"[Song {song_id}] ✓ Task submitted, ID: {task_id}")

        return {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'data': data,
            'submit_time': time.time(),
            'request_time': request_time,
            'success': True
        }

    except Exception as e:
        print_log(f"[Song {song_id}] ✗ Submission failed: {e}")
        # Record the failure so phase 2 / the report still account for it.
        return {
            'song_id': song_id,
            'song_index': song_index,
            'success': False,
            'error': str(e)
        }
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def wait_and_download_result(api, task_info, output_dir):
    """Phase 2: wait for one submitted task to finish, then download its tracks.

    Not rate limited (only the submission endpoint is). Returns a result dict
    (success or failure) and appends it to generation_results.csv via
    save_result_record(); never raises.

    Args:
        api: SunoAPI client.
        task_info: Dict produced by submit_generation_task(); passed through
            unchanged if that submission already failed.
        output_dir: Directory for the .mp3 and *_lyrics.json files.
    """
    if not task_info['success']:
        return task_info

    song_id = task_info['song_id']
    song_index = task_info['song_index']
    task_id = task_info['task_id']
    data = task_info['data']
    start_time = task_info['submit_time']

    try:
        original_lyrics = data.get("original_lyrics", data.get("lyrics", ""))
        lyrics = data.get("lyrics", "")
        description = data.get("description", "")

        print_log(f"[Song {song_id}] Waiting for generation to complete...")

        # Poll until SUCCESS/FAILED/timeout; returns polling statistics too.
        wait_result = api.wait_for_completion(task_id, max_wait_time=600, check_interval=8)
        result = wait_result['result']

        # Locate the track list: the API response shape varies, so try the
        # known keys first, then fall back to any list whose first element
        # carries an 'audioUrl'.
        tracks = []
        if isinstance(result, dict):
            if 'data' in result:
                tracks = result['data']
            elif 'sunoData' in result:
                tracks = result['sunoData']
            else:
                for key, value in result.items():
                    if isinstance(value, list) and len(value) > 0 and 'audioUrl' in value[0]:
                        tracks = value
                        break

        if not tracks:
            raise Exception("Audio track data not found")

        # Download-phase statistics
        download_start = time.time()
        downloaded_files = []
        total_download_bytes = 0
        download_count = 0

        # Process each track: download audio (best-effort), fetch timestamped
        # lyrics (best-effort), and always write the lyrics/metadata JSON.
        for track_idx, track in enumerate(tracks):
            audio_url = track.get('audioUrl') or track.get('audio_url')
            audio_id = track.get('id')

            base_filename = f"{song_id}_{track_idx}"
            audio_path = os.path.join(output_dir, f"{base_filename}.mp3")
            lyrics_path = os.path.join(output_dir, f"{base_filename}_lyrics.json")

            # Download audio; a failed download is logged inside download_file
            # and simply skipped here (the song still counts as successful).
            if audio_url:
                download_result = api.download_file(audio_url, audio_path)
                if download_result['success']:
                    downloaded_files.append(audio_path)
                    total_download_bytes += download_result['bytes']
                    download_count += 1

            # Get timestamped lyrics (non-fatal on failure).
            timestamped_lyrics_data = None
            if audio_id:
                try:
                    lyrics_response = api.get_timestamped_lyrics(task_id, audio_id)
                    if lyrics_response.get('code') == 200:
                        timestamped_lyrics_data = lyrics_response.get('data')
                except Exception as e:
                    print_log(f"[Song {song_id}] Track {track_idx+1}: Failed to get lyrics: {e}")

            # Save lyrics and metadata alongside the audio, including the raw
            # track payload for later inspection.
            lyrics_content = {
                "song_id": song_id,
                "song_index": song_index,
                "track_index": track_idx,
                "original_lyrics": original_lyrics,
                "cleaned_lyrics": lyrics,
                "timestamped_lyrics": timestamped_lyrics_data,
                "style": description,
                "full_track_data": track
            }

            with open(lyrics_path, 'w', encoding='utf-8') as f:
                json.dump(lyrics_content, f, ensure_ascii=False, indent=2)
            downloaded_files.append(lyrics_path)

        download_time = time.time() - download_start
        total_time = time.time() - start_time

        print_log(f"[Song {song_id}] ✓ Complete! {len(tracks)} tracks, took {total_time:.1f} seconds")

        final_result = {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'success': True,
            'tracks_count': len(tracks),
            'files': downloaded_files,
            'total_time': total_time,
            'submit_time': start_time,
            'wait_time': wait_result['wait_time'],
            'poll_count': wait_result['poll_count'],
            'avg_poll_time': wait_result['avg_poll_time'],
            'download_time': download_time,
            'download_bytes': total_download_bytes,
            'download_count': download_count,
            'avg_download_speed': total_download_bytes / download_time if download_time > 0 else 0
        }

        # Persist the result immediately so a crash later loses nothing.
        save_result_record(output_dir, final_result)
        return final_result

    except Exception as e:
        total_time = time.time() - start_time
        print_log(f"[Song {song_id}] ✗ Processing failed: {e} (took {total_time:.1f} seconds)")

        error_result = {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'success': False,
            'error': str(e),
            'total_time': total_time,
            'submit_time': start_time
        }

        # Persist the failure record immediately as well.
        save_result_record(output_dir, error_result)
        return error_result
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def format_bytes(bytes_size):
    """Render a byte count as a human-readable string with two decimals.

    Scales through B/KB/MB/GB (factor 1024); anything beyond GB is shown in TB.
    """
    units = ('B', 'KB', 'MB', 'GB')
    value = float(bytes_size)
    idx = 0
    # Divide down until the value fits under 1024 or we run out of units.
    while idx < len(units) and value >= 1024.0:
        value /= 1024.0
        idx += 1
    if idx < len(units):
        return f"{value:.2f} {units[idx]}"
    return f"{value:.2f} TB"
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def format_speed(bytes_per_sec):
    """Render a transfer rate as a human-readable '<size>/s' string."""
    size_text = format_bytes(bytes_per_sec)
    return size_text + "/s"
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def main():
    """Main program — two-phase concurrent batch generation.

    Phase 1 submits all tasks through the rate limiter with a small thread
    pool; phase 2 waits/downloads with a larger pool (polling is not rate
    limited). Ends with a detailed performance report. Input/output paths are
    hard-coded below.
    """
    input_file = "cleaned_data_truncated.json"
    output_dir = "sunov5_truncated"
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Initialize logging: rebind the module-level logger with file+console handlers.
    global logger
    logger, log_file = setup_logging(output_dir)

    print_log("=" * 70)
    print_log("Suno API Batch Generation - V5 Special Edition")
    print_log("Strategy: Fast submission (5 requests/10s) + Parallel waiting + Detailed performance analysis")
    print_log(f"Log file: {log_file}")
    print_log("=" * 70)

    # Read input file: accepts a JSON array, or (for .jsonl) one object per line.
    try:
        all_data = []
        if input_file.endswith('.jsonl'):
            try:
                with open(input_file, 'r', encoding='utf-8') as f:
                    # Peek at the first line to detect a mislabeled JSON array.
                    first_line = f.readline().strip()
                    if first_line.startswith('['):
                        # Looks like a regular JSON array despite the .jsonl extension.
                        f.seek(0)
                        all_data = json.load(f)
                    else:
                        # True JSONL: parse line by line.
                        f.seek(0)
                        for i, line in enumerate(f):
                            line = line.strip()
                            if line:
                                all_data.append(json.loads(line))
            except json.JSONDecodeError:
                # Last resort: re-read the whole file as regular JSON.
                print_log(f"Note: Failed to parse {input_file} as JSONL format, trying as regular JSON...")
                with open(input_file, 'r', encoding='utf-8') as f:
                    all_data = json.load(f)
        else:
            with open(input_file, 'r', encoding='utf-8') as f:
                all_data = json.load(f)

    except FileNotFoundError:
        print_log(f"File {input_file} not found.")
        return
    except json.JSONDecodeError as e:
        print_log(f"JSON parsing error: {e}")
        return

    # Initialize API client (SUNO_API_KEY comes from module-level config).
    api = SunoAPI(SUNO_API_KEY)

    print_log(f"\nPreparing to generate {len(all_data)} songs...")
    print_log(f"Start time: {time.strftime('%H:%M:%S')}\n")

    overall_start_time = time.time()

    # ===== Phase 1: Batch Submission =====
    print_log("\n" + "=" * 70)
    print_log("Phase 1: Batch Submission")
    print_log("=" * 70 + "\n")

    submit_start_time = time.time()
    submitted_tasks = []
    total_request_time = 0

    # Reset the global limiter to 5 requests / 10 seconds for this run.
    rate_limiter.max_requests = 5
    rate_limiter.time_window = 10
    rate_limiter.request_times.clear()
    print_log(f"Rate limit: {rate_limiter.max_requests} requests / {rate_limiter.time_window} seconds")

    # Build the (index, data) task list; currently submits everything.
    tasks_to_run = []
    for i, data in enumerate(all_data, 1):
        tasks_to_run.append((i, data))

    print_log(f"Number of tasks to submit: {len(tasks_to_run)}")

    # Thread pool for submission: throughput is governed by rate_limiter,
    # so 5 workers is sufficient.
    with ThreadPoolExecutor(max_workers=5) as executor:
        submit_futures = {
            executor.submit(submit_generation_task, api, idx, data): idx
            for idx, data in tasks_to_run
        }

        with tqdm(total=len(tasks_to_run), desc="Submitting tasks", unit="song") as pbar:
            for future in as_completed(submit_futures):
                result = future.result()
                submitted_tasks.append(result)
                if result.get('success') and 'request_time' in result:
                    total_request_time += result['request_time']
                pbar.update(1)

    submit_phase_time = time.time() - submit_start_time
    success_submits = sum(1 for t in submitted_tasks if t['success'])

    # Get rate limit statistics
    rate_limit_stats = rate_limiter.get_stats()

    print_log(f"\nSubmission phase complete: {success_submits}/{len(tasks_to_run)} successful")
    print_log(f" Total time: {submit_phase_time:.1f} seconds")
    print_log(f" Actual request time: {total_request_time:.2f} seconds")
    print_log(f" Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds ({rate_limit_stats['wait_count']} times)")
    if rate_limit_stats['wait_count'] > 0:
        print_log(f" Average wait time: {rate_limit_stats['avg_wait_time']:.2f} seconds/time")

    # ===== Phase 2: Parallel Waiting and Download =====
    print_log("\n" + "=" * 70)
    print_log("Phase 2: Wait for Generation and Download")
    print_log("=" * 70 + "\n")

    wait_start_time = time.time()
    final_results = []

    # More workers here: polling/downloading is not rate limited.
    with ThreadPoolExecutor(max_workers=20) as executor:
        download_futures = {
            executor.submit(wait_and_download_result, api, task, output_dir): task
            for task in submitted_tasks if task['success']
        }

        # Failed submissions go straight into the final results.
        for task in submitted_tasks:
            if not task['success']:
                final_results.append(task)

        with tqdm(total=len(download_futures), desc="Downloading results", unit="song") as pbar:
            for future in as_completed(download_futures):
                result = future.result()
                final_results.append(result)
                pbar.update(1)

    wait_phase_time = time.time() - wait_start_time

    # ===== Detailed Statistics and Report =====
    overall_time = time.time() - overall_start_time

    print_log("\n" + "=" * 70)
    print_log("Batch Generation Complete - Detailed Performance Report")
    print_log("=" * 70)

    success_count = sum(1 for r in final_results if r.get('success'))
    fail_count = len(final_results) - success_count
    total_tracks = sum(r.get('tracks_count', 0) for r in final_results if r.get('success'))

    successful_results = [r for r in final_results if r.get('success')]

    # Basic Statistics
    print_log(f"\n[Basic Statistics]")
    print_log(f" Total songs: {len(all_data)}")
    print_log(f" Successful: {success_count}")
    print_log(f" Failed: {fail_count}")
    print_log(f" Total tracks: {total_tracks}")
    if success_count > 0:
        avg_tracks = total_tracks / success_count
        print_log(f" Average tracks per song: {avg_tracks:.2f}")

    # Time Statistics
    print_log(f"\n[Time Statistics]")
    print_log(f" ├── Submission phase: {submit_phase_time:.1f} seconds")
    print_log(f" │ ├── Actual request time: {total_request_time:.2f} seconds")
    print_log(f" │ └── Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds")
    print_log(f" ├── Generation waiting phase: {wait_phase_time:.1f} seconds")

    if successful_results:
        wait_times = [r.get('wait_time', 0) for r in successful_results if 'wait_time' in r]
        download_times = [r.get('download_time', 0) for r in successful_results if 'download_time' in r]

        if wait_times:
            avg_wait = sum(wait_times) / len(wait_times)
            min_wait = min(wait_times)
            max_wait = max(wait_times)
            print_log(f" │ ├── Average wait time: {avg_wait:.1f} seconds/song")
            print_log(f" │ ├── Fastest: {min_wait:.1f} seconds")
            print_log(f" │ └── Slowest: {max_wait:.1f} seconds")

        if download_times:
            total_download_time = sum(download_times)
            avg_download = total_download_time / len(download_times)
            print_log(f" ├── Download phase: {total_download_time:.1f} seconds")
            print_log(f" │ └── Average download time: {avg_download:.2f} seconds/song")

    print_log(f" └── Total time: {overall_time:.1f} seconds ({overall_time/60:.1f} minutes)")

    # Single Song Generation Statistics
    if successful_results:
        total_times = [r.get('total_time', 0) for r in successful_results if 'total_time' in r]
        if total_times:
            print_log(f"\n[Single Song Generation Statistics]")
            avg_time = sum(total_times) / len(total_times)
            min_time = min(total_times)
            max_time = max(total_times)
            print_log(f" Average total time per song: {avg_time:.1f} seconds")
            print_log(f" Fastest generation: {min_time:.1f} seconds")
            print_log(f" Slowest generation: {max_time:.1f} seconds")

        # Download Statistics
        total_download_bytes = sum(r.get('download_bytes', 0) for r in successful_results)
        total_download_count = sum(r.get('download_count', 0) for r in successful_results)

        if total_download_bytes > 0:
            print_log(f"\n[Download Statistics]")
            print_log(f" Total download: {format_bytes(total_download_bytes)}")
            print_log(f" Number of files: {total_download_count}")
            print_log(f" Average file size: {format_bytes(total_download_bytes / total_download_count)}")

            download_speeds = [r.get('avg_download_speed', 0) for r in successful_results if r.get('avg_download_speed', 0) > 0]
            if download_speeds:
                avg_speed = sum(download_speeds) / len(download_speeds)
                print_log(f" Average download speed: {format_speed(avg_speed)}")

        # Polling Statistics
        poll_counts = [r.get('poll_count', 0) for r in successful_results if 'poll_count' in r]
        if poll_counts:
            total_polls = sum(poll_counts)
            avg_polls = total_polls / len(poll_counts)
            print_log(f"\n[Polling Statistics]")
            print_log(f" Total polling count: {total_polls}")
            print_log(f" Average polling per song: {avg_polls:.1f}")

    # Efficiency Analysis
    print_log(f"\n[Efficiency Analysis]")
    if success_count > 0:
        throughput = success_count / (overall_time / 60)
        print_log(f" Actual throughput: {throughput:.2f} songs/minute")

        # Theoretical fastest time (assuming no rate limit): submission minus
        # rate-limit waits, plus the slowest single generation.
        if wait_times:
            ideal_time = submit_phase_time - rate_limit_stats['total_wait_time'] + max(wait_times)
            efficiency = (ideal_time / overall_time) * 100
            print_log(f" Theoretical fastest time: {ideal_time:.1f} seconds")
            print_log(f" Concurrency efficiency: {efficiency:.1f}%")

    # Show failed songs
    if fail_count > 0:
        print_log("\n" + "=" * 70)
        print_log("Failed Songs List")
        print_log("=" * 70)
        for r in sorted(final_results, key=lambda x: x.get('song_index', 0)):
            if not r.get('success'):
                song_id = r.get('song_id', r.get('song_index', 'Unknown'))
                print_log(f" [{song_id}] {r.get('error', 'Unknown error')}")

    print_log("\n" + "=" * 70)
    print_log(f"All files saved to: {os.path.abspath(output_dir)}")
    print_log("=" * 70)
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
# Script entry point.
if __name__ == '__main__':
    main()
|
| 768 |
+
|
baseline_generate/yue/__pycache__/infer_batch.cpython-311.pyc
ADDED
|
Binary file (58.8 kB). View file
|
|
|
baseline_generate/yue/batch.sh
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Batch music generation example script
# Requires yue source repository, main code change is infer_batch.py replacing infer.py
# NOTE(review): JSONL_PATH and OUTPUT_DIR below are empty placeholders and must
# be filled in before running.

# Get absolute path of script directory (to avoid filesystem mount issues);
# falls back to a hard-coded placeholder path if cd/pwd fails.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd 2>/dev/null || echo "xxx/YuE/inference")"

# Change to script directory (if possible, otherwise rely on absolute paths below)
cd "$SCRIPT_DIR" 2>/dev/null || true


# Set HuggingFace mirror
export HF_ENDPOINT=https://hf-mirror.com

# Set PyTorch CUDA memory management optimization
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Restrict the run to GPU 1 (infer_batch.py then addresses it as device 0 via CUDA_IDX)
export CUDA_VISIBLE_DEVICES=1

# Set JSONL file path (input songs; REQUIRED)
JSONL_PATH=""

# Set output directory (REQUIRED)
OUTPUT_DIR=""

# Set processing range (optional); END_IDX=-1 means "through the end"
# Example: only process first 5 songs
START_IDX=0
END_IDX=-1

# Set generation parameters
MAX_NEW_TOKENS=3500
REPETITION_PENALTY=1.1
RUN_N_SEGMENTS=24
STAGE2_BATCH_SIZE=16
CUDA_IDX=0
SEED=42
NO_SAMPLE=0

# Run batch generation (using absolute path). The trailing command
# substitution appends --no_sample only when NO_SAMPLE=1.
python "$SCRIPT_DIR/infer_batch.py" \
    --jsonl_path "$JSONL_PATH" \
    --output_dir "$OUTPUT_DIR" \
    --start_idx $START_IDX \
    --end_idx $END_IDX \
    --max_new_tokens $MAX_NEW_TOKENS \
    --repetition_penalty $REPETITION_PENALTY \
    --run_n_segments $RUN_N_SEGMENTS \
    --stage2_batch_size $STAGE2_BATCH_SIZE \
    --cuda_idx $CUDA_IDX \
    --seed $SEED \
    --rescale \
    $( [ "$NO_SAMPLE" -eq 1 ] && echo "--no_sample" )
|
| 55 |
+
|
baseline_generate/yue/codecmanipulator.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import numpy as np
|
| 3 |
+
import einops
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CodecManipulator(object):
    r"""Convert raw codec token matrices to/from the flat mm-token id space.

    **mm tokenizer v0.1** -- see codeclm/hf/mm_tokenizer_v0.1_hf/id2vocab.json

    text tokens:
        llama tokenizer 0~31999

    special tokens: "32000": "<EOD>", "32001": "<SOA>", "32002": "<EOA>", "32003": "<SOI>", "32004": "<EOI>", "32005": "<SOV>", "32006": "<EOV>", "32007": "<s_local>", "32008": "<e_local>", "32009": "<s_global>", "32010": "<e_global>", "32011": "<semantic>", "32012": "<acoustic>", "32013": "<low_level>", "32014": "<dac_16k>", "32015": "<dac_44k>", "32016": "<xcodec>", "32017": "<placeholder>", "32018": "<semantic_mert>", "32019": "<semantic_hubert>", "32020": "<visual>", "32021": "<semanticodec>"

    mm tokens:
        dac_16k: 4 codebook, 1024 vocab, 32022 - 36117
        dac_44k: 9 codebook, 1024 vocab, 36118 - 45333
        xcodec: 12 codebook, 1024 vocab, 45334 - 57621
        semantic mert: 1024, 57622 - 58645
        semantic hubert: 512, 58646 - 59157
        visual: 64000, not included in v0.1
        semanticodec 100tps 16384: semantic=16384, 59158 - 75541, acoustic=8192, 75542 - 83733
    """

    def __init__(self, codec_type, quantizer_begin=None, n_quantizer=None, teacher_forcing=False, data_feature="codec"):
        """Configure the manipulator for one codec family.

        Args:
            codec_type: key into ``mm_v0_2_cfg`` (e.g. "xcodec", "dac16k").
            quantizer_begin: first codebook layer to use (defaults to 0).
            n_quantizer: number of codebook layers to use (defaults to the
                codec's full ``num_codebooks``).
            teacher_forcing: stored flag; not used inside this class.
            data_feature: stored tag; not used inside this class.
        """
        self.codec_type = codec_type
        self.mm_v0_2_cfg = {
            "dac16k": {"codebook_size": 1024, "num_codebooks": 4, "global_offset": 32022, "sep": ["<dac_16k>"], "fps": 50},
            "dac44k": {"codebook_size": 1024, "num_codebooks": 9, "global_offset": 36118, "sep": ["<dac_44k>"]},
            "xcodec": {"codebook_size": 1024, "num_codebooks": 12, "global_offset": 45334, "sep": ["<xcodec>"], "fps": 50},
            "mert": {"codebook_size": 1024, "global_offset": 57622, "sep": ["<semantic_mert>"]},
            "hubert": {"codebook_size": 512, "global_offset": 58646, "sep": ["<semantic_hubert>"]},
            "semantic/s": {"codebook_size": 16384, "num_codebooks": 1, "global_offset": 59158, "sep": ["<semanticodec>", "<semantic>"]},
            "semantic/a": {"codebook_size": 8192, "num_codebooks": 1, "global_offset": 75542, "sep": ["<semanticodec>", "<acoustic>"]},
            "semanticodec": {"codebook_size": [16384, 8192], "num_codebooks": 2, "global_offset": 59158, "sep": ["<semanticodec>"], "fps": 50},
            "special_tokens": {
                '<EOD>': 32000, '<SOA>': 32001, '<EOA>': 32002, '<SOI>': 32003, '<EOI>': 32004, '<SOV>': 32005, '<EOV>': 32006, '<s_local>': 32007, '<e_local>': 32008, '<s_global>': 32009, '<e_global>': 32010, '<semantic>': 32011, '<acoustic>': 32012, '<stage_1>': 32013, '<dac_16k>': 32014, '<dac_44k>': 32015, '<xcodec>': 32016, '<stage_2>': 32017, '<semantic_mert>': 32018, '<semantic_hubert>': 32019, '<visual>': 32020, '<semanticodec>': 32021
            },
            "metadata": {
                "len": 83734,
                "text_range": [0, 31999],
                "special_range": [32000, 32021],
                "mm_range": [32022, 83733]
            },
            "codec_range": {
                "dac16k": [32022, 36117],
                "dac44k": [36118, 45333],
                "xcodec": [45334, 57621],
                # "hifi16k": [53526, 57621],
                "mert": [57622, 58645],
                "hubert": [58646, 59157],
                "semantic/s": [59158, 75541],
                "semantic/a": [75542, 83733],
                "semanticodec": [59158, 83733]
            }
        }
        # NOTE(review): these attribute names shadow the sep()/sep_ids()
        # methods below -- on instances the attributes win.
        self.sep = self.mm_v0_2_cfg[self.codec_type]["sep"]
        self.sep_ids = [self.mm_v0_2_cfg["special_tokens"][s] for s in self.sep]
        self.codebook_size = self.mm_v0_2_cfg[self.codec_type]["codebook_size"]
        self.num_codebooks = self.mm_v0_2_cfg[self.codec_type]["num_codebooks"]
        self.global_offset = self.mm_v0_2_cfg[self.codec_type]["global_offset"]
        self.fps = self.mm_v0_2_cfg[self.codec_type]["fps"] if "fps" in self.mm_v0_2_cfg[self.codec_type] else None

        self.quantizer_begin = quantizer_begin if quantizer_begin is not None else 0
        self.n_quantizer = n_quantizer if n_quantizer is not None else self.num_codebooks
        self.teacher_forcing = teacher_forcing
        self.data_feature = data_feature

    def offset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
        """Shift per-codebook token ids into the global mm-token id space.

        Args:
            x: (K, T) integer array of raw per-codebook token ids.

        Returns:
            (n_quantizer, T) uint32 array, rows quantizer_begin..quantizer_end-1,
            each row offset by ``global_offset + <cumulative codebook offset>``.
        """
        if isinstance(codebook_size, int):
            assert x.max() < codebook_size, f"max(x)={x.max()}, codebook_size={codebook_size}"
        elif isinstance(codebook_size, list):
            for i, cs in enumerate(codebook_size):
                assert x[i].max() < cs, f"max(x)={x[i].max()}, codebook_size={cs}, layer_id={i}"
        else:
            raise ValueError(f"codebook_size={codebook_size}")
        assert x.min() >= 0, f"min(x)={x.min()}"
        assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"

        _x = x.copy()
        _x = _x.astype(np.uint32)
        cum_offset = 0
        quantizer_begin = self.quantizer_begin
        quantizer_end = quantizer_begin + self.n_quantizer
        # NOTE(review): for list-valued codebook_size with quantizer_begin > 0,
        # cum_offset starts at 0 rather than the sum of the skipped layers --
        # confirm intent if that configuration is ever used.
        for k in range(self.quantizer_begin, quantizer_end):  # k: quantizer_begin to quantizer_end - 1
            if isinstance(codebook_size, int):
                _x[k] += global_offset + k * codebook_size
            elif isinstance(codebook_size, list):
                _x[k] += global_offset + cum_offset
                cum_offset += codebook_size[k]
            else:
                raise ValueError(f"codebook_size={codebook_size}")
        return _x[quantizer_begin:quantizer_end]

    def unoffset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
        """Inverse of :meth:`offset_tok_ids`.

        Args:
            x: (K, T) array of global mm-token ids (already sliced to the
               quantizer_begin..quantizer_end rows).

        Returns:
            (K, T) uint32 array of raw per-codebook token ids.
        """
        if isinstance(codebook_size, int):
            assert x.max() < global_offset + codebook_size * num_codebooks, f"max(x)={x.max()}, codebook_size={codebook_size}"
        elif isinstance(codebook_size, list):
            assert x.max() < global_offset + sum(codebook_size), f"max(x)={x.max()}, codebook_size={codebook_size}"
        assert x.min() >= global_offset, f"min(x)={x.min()}, global_offset={global_offset}"
        assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"

        _x = x.copy()
        _x = _x.astype(np.uint32)
        cum_offset = 0
        quantizer_begin = self.quantizer_begin
        quantizer_end = quantizer_begin + self.n_quantizer
        for k in range(quantizer_begin, quantizer_end):
            if isinstance(codebook_size, int):
                _x[k - quantizer_begin] -= global_offset + k * codebook_size
            elif isinstance(codebook_size, list):
                _x[k - quantizer_begin] -= global_offset + cum_offset
                cum_offset += codebook_size[k]
            else:
                raise ValueError(f"codebook_size={codebook_size}")
        return _x

    def flatten(self, x):
        """Interleave a (K, T) code matrix into a 1-D time-major sequence.

        Equivalent to einops.rearrange(x, 'K T -> (T K)'); written with plain
        numpy to drop the einops dependency for this trivial transpose.
        """
        if len(x.shape) > 2:
            x = x.squeeze()
        assert x.shape[0] == self.num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
        return x.T.reshape(-1)

    def unflatten(self, x, n_quantizer=None):
        """Inverse of :meth:`flatten`: reshape a flat (T*K,) sequence to (K, T).

        Bug fix: the original crashed when ``n_quantizer`` was left as None
        (it passed K=None to the rearrange); None now defaults to
        ``self.num_codebooks``. Behavior for non-None values is unchanged.
        """
        if x.ndim > 1 and x.shape[0] == 1:
            x = x.squeeze(0)
        assert len(x.shape) == 1
        assert x.shape[0] % self.num_codebooks == 0 or x.shape[0] % self.n_quantizer == 0, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
        k = self.num_codebooks if n_quantizer is None else n_quantizer
        # '(T K) -> K T'
        return x.reshape(-1, k).T

    # def check_codec_type_from_path(self, path):
    #     if self.codec_type == "hifi16k":
    #         assert "academicodec_hifi_16k_320d_large_uni" in path

    def get_codec_type_from_range(self, ids):
        """Return the codec family whose id range contains all of ``ids``.

        Raises:
            ValueError: when [min(ids), max(ids)] fits no configured range.
        """
        ids_range = [ids.min(), ids.max()]
        codec_range = self.mm_v0_2_cfg["codec_range"]
        for codec_type, r in codec_range.items():
            if ids_range[0] >= r[0] and ids_range[1] <= r[1]:
                return codec_type
        raise ValueError(f"ids_range={ids_range}, codec_range={codec_range}")

    def npy2ids(self, npy):
        """Load a (K, T) code matrix (path or array) and return the flat
        offset token-id list, validating that the ids land in this codec's range."""
        if isinstance(npy, str):
            data = np.load(npy)
        elif isinstance(npy, np.ndarray):
            data = npy
        else:
            raise ValueError(f"not supported type: {type(npy)}")
        # data = data.squeeze()

        assert len(data.shape) == 2, f'data shape: {data.shape} is not (n_codebook, seq_len)'
        data = self.offset_tok_ids(
            data,
            global_offset=self.global_offset,
            codebook_size=self.codebook_size,
            num_codebooks=self.num_codebooks,
        )
        data = self.flatten(data)
        codec_range = self.get_codec_type_from_range(data)
        assert codec_range == self.codec_type, f"get_codec_type_from_range(data)={codec_range}, self.codec_type={self.codec_type}"
        data = data.tolist()
        return data

    def ids2npy(self, token_ids):
        """Inverse of :meth:`npy2ids`: flat token-id list -> (K, T) raw codes."""
        # make sure token_ids starts with codebook 0
        if isinstance(self.codebook_size, int):
            codebook_0_range = (self.global_offset + self.quantizer_begin * self.codebook_size, self.global_offset + (self.quantizer_begin + 1) * self.codebook_size)
        elif isinstance(self.codebook_size, list):
            codebook_0_range = (self.global_offset, self.global_offset + self.codebook_size[0])
        # Bug fix: the failure message previously printed
        # token_ids[self.quantizer_begin] while the condition checks token_ids[0].
        assert token_ids[0] >= codebook_0_range[0] \
            and token_ids[0] < codebook_0_range[1], f"token_ids[0]={token_ids[0]}, codebook_0_range={codebook_0_range}"
        data = np.array(token_ids)
        data = self.unflatten(data, n_quantizer=self.n_quantizer)
        data = self.unoffset_tok_ids(
            data,
            global_offset=self.global_offset,
            codebook_size=self.codebook_size,
            num_codebooks=self.num_codebooks,
        )
        return data

    def npy_to_json_str(self, npy_path):
        """Serialize the flat ids of a .npy file as a one-line JSON record."""
        data = self.npy2ids(npy_path)
        return json.dumps({"text": data, "src": npy_path, "codec": self.codec_type})

    # NOTE(review): the two methods below are shadowed by the instance
    # attributes self.sep / self.sep_ids assigned in __init__, so they are
    # unreachable on instances. Kept only for source compatibility.
    def sep(self):
        return ''.join(self.sep)

    def sep_ids(self):
        return self.sep_ids
|
baseline_generate/yue/infer_batch.py
ADDED
|
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
|
| 4 |
+
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
|
| 5 |
+
import re
|
| 6 |
+
import random
|
| 7 |
+
import uuid
|
| 8 |
+
import copy
|
| 9 |
+
import json
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from collections import Counter
|
| 12 |
+
import argparse
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torchaudio
|
| 16 |
+
from torchaudio.transforms import Resample
|
| 17 |
+
import soundfile as sf
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
|
| 20 |
+
from omegaconf import OmegaConf
|
| 21 |
+
from codecmanipulator import CodecManipulator
|
| 22 |
+
from mmtokenizer import _MMSentencePieceTokenizer
|
| 23 |
+
from models.soundstream_hubert_new import SoundStream
|
| 24 |
+
from vocoder import build_codec_model, process_audio
|
| 25 |
+
from post_process_audio import replace_low_freq_with_energy_matched
|
| 26 |
+
|
| 27 |
+
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

parser = argparse.ArgumentParser()
# Model Configuration:
parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
parser.add_argument("--stage2_model", type=str, default="m-a-p/YuE-s2-1B-general", help="The model checkpoint path or identifier for the Stage 2 model.")
parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
parser.add_argument("--repetition_penalty", type=float, default=1.1, help="repetition_penalty ranges from 1.0 to 2.0 (or higher in some cases). It controls the diversity and coherence of the audio tokens generated. The higher the value, the greater the discouragement of repetition. Setting value to 1.0 means no penalty.")
parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during generation. Each segment is ~30s (with default max_new_tokens=3000). For example: 2=~1min, 6=~3min, 8=~4min.")
parser.add_argument("--stage2_batch_size", type=int, default=4, help="The batch size used in Stage 2 inference.")
parser.add_argument(
    "--no_sample",
    action="store_true",
    help="If set, disable sampling in Stage 1 generation (i.e., use deterministic decoding). When enabled, top_p/temperature will be ignored.",
)
# Prompt - Batch processing parameters
parser.add_argument("--jsonl_path", type=str, required=True, help="The file path to a JSONL file containing genre and lyrics for batch processing.")
parser.add_argument("--start_idx", type=int, default=0, help="Start index in the JSONL file for batch processing.")
parser.add_argument("--end_idx", type=int, default=-1, help="End index in the JSONL file for batch processing. -1 means process all.")
parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
parser.add_argument("--use_dual_tracks_prompt", action="store_true", help="If set, the model will use dual tracks as a prompt during generation. The vocal and instrumental files should be specified using --vocal_track_prompt_path and --instrumental_track_prompt_path.")
parser.add_argument("--vocal_track_prompt_path", type=str, default="", help="The file path to a vocal track file to use as a reference prompt when --use_dual_tracks_prompt is enabled.")
parser.add_argument("--instrumental_track_prompt_path", type=str, default="", help="The file path to an instrumental track file to use as a reference prompt when --use_dual_tracks_prompt is enabled.")
# Output
parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
parser.add_argument("--cuda_idx", type=int, default=0)
parser.add_argument("--seed", type=int, default=42, help="An integer value to reproduce generation.")
# Config for xcodec and upsampler
parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML files for xcodec configurations.')
parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to Vocos config file.')
parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')


args = parser.parse_args()
if args.use_audio_prompt and not args.audio_prompt_path:
    raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
# Bug fix: the original only raised when BOTH track paths were missing
# (`not a and not b`); dual-track prompting needs both, so raise when
# EITHER is missing. The message also named the wrong flag
# ('--inst_decoder_path' instead of '--instrumental_track_prompt_path').
if args.use_dual_tracks_prompt and not (args.vocal_track_prompt_path and args.instrumental_track_prompt_path):
    raise FileNotFoundError("Please offer dual tracks prompt filepath using '--vocal_track_prompt_path' and '--instrumental_track_prompt_path', when you enable '--use_dual_tracks_prompt'!")

stage1_model = args.stage1_model
stage2_model = args.stage2_model
cuda_idx = args.cuda_idx
max_new_tokens = args.max_new_tokens
do_sample_stage1 = (not args.no_sample)
|
| 79 |
+
|
| 80 |
+
def seed_everything(seed=42):
    """Seed every RNG used by the pipeline (python, numpy, torch) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cudnn autotuning for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 87 |
+
|
| 88 |
+
seed_everything(args.seed)

# Load the batch prompts: one JSON object per non-blank line.
print(f"Reading JSONL file: {args.jsonl_path}")
with open(args.jsonl_path, 'r', encoding='utf-8') as f:
    music_data_list = [json.loads(line) for line in f if line.strip()]

# Clamp the requested slice to what the file actually contains (-1 == "to the end").
start_idx = args.start_idx
if args.end_idx == -1:
    end_idx = len(music_data_list)
else:
    end_idx = min(args.end_idx, len(music_data_list))
music_data_list = music_data_list[start_idx:end_idx]
print(f"Total {len(music_data_list)} songs to generate (indices {start_idx} to {end_idx-1})")
|
| 103 |
+
|
| 104 |
+
# Detect processed songs - check completion status of each stage
|
| 105 |
+
def check_song_status(song_idx, output_dir):
    """
    Inspect the output tree and report how far song `song_idx` has progressed.

    Returns a 6-tuple:
        (stage1_done, stage2_done, stage3_done, song_dir, stage1_output_set, stage2_output_dir)
    where the first three are booleans and the last three are None when no
    matching song directory exists yet.
    """
    if not os.path.exists(output_dir):
        return False, False, False, None, None, None

    # Collect every directory named song_<idx>[...] for this index.
    candidates = []
    for entry in os.listdir(output_dir):
        full = os.path.join(output_dir, entry)
        if not (entry.startswith('song_') and os.path.isdir(full)):
            continue
        try:
            if int(entry.split('_')[1]) == song_idx:
                candidates.append(full)
        except (ValueError, IndexError):
            pass

    if not candidates:
        return False, False, False, None, None, None

    # Prefer the most recently modified run when several exist.
    song_dir = max(candidates, key=os.path.getmtime)

    # Stage 1 is done once both vocal and instrumental token files exist.
    stage1_dir = os.path.join(song_dir, "stage1")
    stage1_done = False
    stage1_output_set = []
    if os.path.exists(stage1_dir):
        npy_files = [f for f in os.listdir(stage1_dir) if f.endswith('.npy')]
        vocal_files = [f for f in npy_files if '_vtrack' in f]
        inst_files = [f for f in npy_files if '_itrack' in f]
        if vocal_files and inst_files:
            stage1_done = True
            stage1_output_set = [os.path.join(stage1_dir, f) for f in vocal_files + inst_files]

    # Stage 2 is done once every stage-1 artifact has a same-named counterpart.
    stage2_dir = os.path.join(song_dir, "stage2")
    stage2_done = False
    if stage1_done and os.path.exists(stage2_dir):
        stage2_names = {f for f in os.listdir(stage2_dir) if f.endswith('.npy')}
        if stage1_output_set:
            needed = {os.path.basename(f) for f in stage1_output_set}
            stage2_done = needed.issubset(stage2_names)

    # Stage 3 is done once a final *_mixed.mp3 exists anywhere under song_dir.
    stage3_done = any(
        any(f.endswith('_mixed.mp3') for f in files)
        for _, _, files in os.walk(song_dir)
    )

    return stage1_done, stage2_done, stage3_done, song_dir, stage1_output_set, stage2_dir
|
| 164 |
+
|
| 165 |
+
# Detect processing status of all songs.
# {song_idx: (stage1_done, stage2_done, stage3_done, song_dir, stage1_output_set, stage2_output_dir)}
song_status_map = {}
if os.path.exists(args.output_dir):
    print(f"\nDetecting processed songs...")
    for list_idx in range(len(music_data_list)):
        song_idx = start_idx + list_idx
        status = check_song_status(song_idx, args.output_dir)
        # Record only songs that have made progress in at least one stage.
        if any(status[:3]):
            song_status_map[song_idx] = status
    if song_status_map:
        fully_completed = [idx for idx, st in song_status_map.items() if st[2]]
        partial_completed = [idx for idx, st in song_status_map.items() if not st[2]]
        print(f"✓ Found {len(fully_completed)} fully completed songs: {sorted(fully_completed)}")
        if partial_completed:
            print(f"✓ Found {len(partial_completed)} partially completed songs: {sorted(partial_completed)}")
            for idx in sorted(partial_completed):
                s1, s2, s3 = song_status_map[idx][:3]
                status_parts = [name for flag, name in ((s1, "Stage1"), (s2, "Stage2"), (s3, "Stage3")) if flag]
                print(f"  Index {idx}: Completed {', '.join(status_parts)}")
        remaining_count = len(music_data_list) - len(fully_completed)
        print(f"✓ Will skip fully completed songs, {remaining_count} songs remaining to process")
    else:
        print(f"✓ No processed songs found, will start from the beginning")
else:
    print(f"✓ Output directory does not exist, will start from the beginning")
|
| 194 |
+
|
| 195 |
+
# Load tokenizer and model
device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
print("Loading Stage 1 model...")
model = AutoModelForCausalLM.from_pretrained(
    stage1_model,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # Using flash_attention_2 for better performance
    # device_map="auto",
)
# to device, if gpu is available
model.to(device)
model.eval()

# Bug fix: the original check `torch.__version__ >= "2.0.0"` compares strings
# lexicographically, so e.g. "10.0.0" would be treated as older than "2.0.0".
# Parse the major version number instead (fall back to 0 on exotic strings).
try:
    _torch_major = int(torch.__version__.split(".", 1)[0])
except ValueError:
    _torch_major = 0
if _torch_major >= 2:
    try:
        model = torch.compile(model)
    except Exception as e:
        print(f"Warning: torch.compile not available: {e}")

# Stage 1 consumes 1 codebook; Stage 2 upsamples to 8 codebooks.
codectool = CodecManipulator("xcodec", 0, 1)
codectool_stage2 = CodecManipulator("xcodec", 0, 8)
model_config = OmegaConf.load(args.basic_model_config)
# SECURITY NOTE: eval() instantiates whatever class name the YAML config
# provides -- only load configs from trusted sources.
codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
# Load checkpoint with weights_only=False to allow OmegaConf types
# Note: Only use this if you trust the checkpoint source
parameter_dict = torch.load(args.resume_path, map_location='cpu', weights_only=False)
codec_model.load_state_dict(parameter_dict['codec_model'])
codec_model.to(device)
codec_model.eval()
|
| 225 |
+
|
| 226 |
+
class BlockTokenRangeProcessor(LogitsProcessor):
    """Logits processor that forbids every token id in [start_id, end_id)."""

    def __init__(self, start_id, end_id):
        # Materialized once so each __call__ is a single fancy-index assignment.
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        # -inf logits guarantee zero probability after softmax.
        scores[:, self.blocked_token_ids] = float("-inf")
        return scores
|
| 233 |
+
|
| 234 |
+
def load_audio_mono(filepath, sampling_rate=16000):
    """Load an audio file, downmix to mono, and resample to `sampling_rate`."""
    waveform, source_rate = torchaudio.load(filepath)
    # Average the channels rather than picking one, keeping shape (1, T).
    waveform = torch.mean(waveform, dim=0, keepdim=True)
    if source_rate != sampling_rate:
        waveform = Resample(orig_freq=source_rate, new_freq=sampling_rate)(waveform)
    return waveform
|
| 243 |
+
|
| 244 |
+
def encode_audio(codec_model, audio_prompt, device, target_bw=0.5):
    """Encode an audio prompt into codec token codes.

    Args:
        codec_model: Codec exposing `encode(audio, target_bw)` that returns a
            tensor of code indices.
        audio_prompt: Audio tensor; promoted to 3-D (batch, channel, samples)
            when it has fewer than 3 dims.
        device: Device to run encoding on.
        target_bw: Target bandwidth passed to the codec (default 0.5).

    Returns:
        np.int16 array of codes with the first two axes transposed.
    """
    # Fix: use the out-of-place unsqueeze so the caller's tensor is not
    # mutated as a side effect (the original used in-place `unsqueeze_`).
    if audio_prompt.dim() < 3:
        audio_prompt = audio_prompt.unsqueeze(0)
    with torch.no_grad():
        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=target_bw)
    # Swap (quantizer, batch) -> (batch, quantizer) before moving off-device.
    raw_codes = raw_codes.transpose(0, 1)
    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
    return raw_codes
|
| 252 |
+
|
| 253 |
+
def split_lyrics(lyrics):
    """Split raw lyrics into structured segments, per YuE official best practices.

    Official requirements:
    1. Lyrics are segmented with structure tags: [verse], [chorus], [bridge],
       [outro], etc.
    2. Segments are separated by two newlines.
    3. Each segment should be about 30 seconds (at --max_new_tokens 3000) —
       avoid cramming too many words.
    4. Avoid the unstable [intro] tag; start with [verse] or [chorus].
    5. Multiple languages are supported (English, Chinese, Cantonese,
       Japanese, Korean, ...).

    Args:
        lyrics: Raw lyrics string.

    Returns:
        List of segments, each formatted as "[tag]" + newline + content + two
        trailing newlines.
    """
    # Matches "[any tag]" plus everything up to the next tag or end of text.
    # Handles complex tags such as [Verse 1], [Pre-Chorus], [Chorus (Outro)].
    tag_pattern = re.compile(r"\[([^\]]+)\](.*?)(?=\[|\Z)", re.DOTALL)
    return [
        f"[{match.group(1)}]\n{match.group(2).strip()}\n\n"
        for match in tag_pattern.finditer(lyrics)
    ]
|
| 276 |
+
|
| 277 |
+
def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
    """Write a waveform to disk as 16-bit PCM.

    Args:
        wav: Waveform tensor of shape (channels, samples).
        path: Output file path; parent directories are created if needed.
        sample_rate: Sample rate in Hz.
        rescale: If True, rescale so the peak sits at `limit`; otherwise
            hard-clip samples into [-limit, limit].
    """
    folder_path = os.path.dirname(path)
    # Fix: guard against an empty dirname (bare filename would make
    # os.makedirs('') raise) and avoid the exists/makedirs race with
    # exist_ok=True.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)
    limit = 0.99
    max_val = wav.abs().max()
    # Either rescale the peak to `limit`, or clamp out-of-range samples.
    wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
|
| 285 |
+
|
| 286 |
+
def stage2_generate(model, prompt, batch_size=16):
    """Stage 2: expand 1-codebook stage-1 tokens to the full 8 codebooks.

    Uses module-level globals `codectool`, `mmtokenizer`, and `device`.

    Args:
        model: Stage 2 causal LM (HF generate API).
        prompt: np.ndarray of flattened stage-1 codec ids, shape (1, T).
        batch_size: When > 1, the prompt is cut into `batch_size` windows of
            300 frames (6 s at 50 frames/s) processed as one batch.

    Returns:
        np.ndarray of the newly generated stage-2 token ids (prompt stripped).
    """
    # Undo the stage-1 flattening, then shift ids into global vocab space.
    codec_ids = codectool.unflatten(prompt, n_quantizer=1)
    codec_ids = codectool.offset_tok_ids(
        codec_ids,
        global_offset=codectool.global_offset,
        codebook_size=codectool.codebook_size,
        num_codebooks=codectool.num_codebooks,
    ).astype(np.int32)

    # Prepare prompt_ids based on batch size or single input
    if batch_size > 1:
        # Slice into fixed 300-frame windows and stack along the batch axis.
        codec_list = []
        for i in range(batch_size):
            idx_begin = i * 300
            idx_end = (i + 1) * 300
            codec_list.append(codec_ids[:, idx_begin:idx_end])

        codec_ids = np.concatenate(codec_list, axis=0)
        prompt_ids = np.concatenate(
            [
                np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
                codec_ids,
                np.tile([mmtokenizer.stage_2], (batch_size, 1)),
            ],
            axis=1
        )
    else:
        prompt_ids = np.concatenate([
            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
            codec_ids.flatten(),  # Flatten the 2D array to 1D
            np.array([mmtokenizer.stage_2])
        ]).astype(np.int32)
        prompt_ids = prompt_ids[np.newaxis, ...]

    codec_ids = torch.as_tensor(codec_ids).to(device)
    prompt_ids = torch.as_tensor(prompt_ids).to(device)
    len_prompt = prompt_ids.shape[-1]

    # Constrain sampling to the stage-2 codec-token id range
    # (ids below 46358 and from 53526 up are blocked).
    block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])

    # Teacher forcing generate loop
    # For each stage-1 frame: append its codebook-0 token, then let the model
    # emit exactly 7 new tokens (the remaining 7 codebooks for that frame).
    for frames_idx in range(codec_ids.shape[1]):
        cb0 = codec_ids[:, frames_idx:frames_idx+1]
        prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
        input_ids = prompt_ids

        with torch.no_grad():
            stage2_output = model.generate(input_ids=input_ids,
                min_new_tokens=7,
                max_new_tokens=7,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=block_list,
            )

        # min_new_tokens == max_new_tokens == 7 should make this exact.
        assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
        prompt_ids = stage2_output

    # Return output based on batch size
    if batch_size > 1:
        # Drop the prompt prefix, then chain the batch rows back into one
        # time-ordered 1-D sequence.
        output = prompt_ids.cpu().numpy()[:, len_prompt:]
        output_list = [output[i] for i in range(batch_size)]
        output = np.concatenate(output_list, axis=0)
    else:
        output = prompt_ids[0].cpu().numpy()[len_prompt:]

    return output
|
| 353 |
+
|
| 354 |
+
def sanitize_genres_for_filename(genres, max_length=80):
    """Clean and truncate a genres string so it is safe inside a filename.

    Keeps the result within `max_length` characters (Linux filenames are
    capped at 255 bytes and other parameters share the budget).

    Args:
        genres: Raw genres string (possibly comma-separated tags).
        max_length: Maximum length of the genres part (default 80).

    Returns:
        Cleaned genres string with spaces replaced by hyphens, or "Unknown"
        for empty input.
    """
    if not genres:
        return "Unknown"

    # Replace characters that are unsafe in filenames, then trim edge junk.
    cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', genres)
    cleaned = cleaned.strip('_').strip()

    # For comma-separated tags, keep as many leading tags as fit the budget.
    if ',' in cleaned:
        tags = [t.strip() for t in cleaned.split(',')]
        kept = []
        used = 0
        for t in tags:
            if used + len(t) + 1 > max_length:  # +1 accounts for the comma
                break
            kept.append(t)
            used += len(t) + 1
        if kept:
            cleaned = ','.join(kept)
        else:
            # Even the first tag is too long: hard-truncate it.
            cleaned = tags[0][:max_length] if tags else cleaned[:max_length]

    # Final hard truncation as a safety net.
    if len(cleaned) > max_length:
        cleaned = cleaned[:max_length]

    # Hyphens instead of spaces for consistent filenames.
    return cleaned.replace(' ', '-')
|
| 399 |
+
|
| 400 |
+
def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
    """Run Stage 2 over every stage-1 token file, with resume support.

    Uses module-level globals `stage2_generate`, `codectool_stage2`.

    Args:
        model: Stage 2 causal LM.
        stage1_output_set: Paths to stage-1 token .npy files (one per track).
        stage2_output_dir: Directory where stage-2 .npy outputs are written.
        batch_size: Max number of 6-second windows per stage2_generate call.

    Returns:
        List of stage-2 output file paths (existing + newly generated).
    """
    stage2_result = []
    for i in tqdm(range(len(stage1_output_set)), desc="Stage 2 inference"):
        output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))

        # Resume: reuse files that already have a stage-2 result.
        if os.path.exists(output_filename):
            print(f'{output_filename} stage2 has done.')
            stage2_result.append(output_filename)
            continue

        # Load the prompt
        prompt = np.load(stage1_output_set[i]).astype(np.int32)

        # Only accept 6s segments
        # 50 tokens/s; round duration down to a whole multiple of 6 seconds.
        output_duration = prompt.shape[-1] // 50 // 6 * 6
        num_batch = output_duration // 6

        if num_batch <= batch_size:
            # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
            output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
        else:
            # If num_batch is greater than batch_size, process in chunks of batch_size
            segments = []
            num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)

            for seg in range(num_segments):
                # 300 tokens per 6-second window.
                start_idx = seg * batch_size * 300
                # Ensure the end_idx does not exceed the available length
                end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # Adjust the last segment
                current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
                segment = stage2_generate(
                    model,
                    prompt[:, start_idx:end_idx],
                    batch_size=current_batch_size
                )
                segments.append(segment)

            # Concatenate all the segments
            output = np.concatenate(segments, axis=0)

        # Process the ending part of the prompt
        # Any remainder shorter than 6 s is generated as a single batch.
        if output_duration*50 != prompt.shape[-1]:
            ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
            output = np.concatenate([output, ending], axis=0)
        output = codectool_stage2.ids2npy(output)

        # Fix invalid codes (a dirty solution, which may harm the quality of audio)
        # We are trying to find better one
        # NOTE(review): the inner loops rebind `i`, shadowing the outer file
        # index within this iteration (harmless here since `output_filename`
        # was computed earlier, but fragile). Also, the "most frequent"
        # replacement could itself be out of range if invalid codes dominate
        # a row — worth confirming against upstream YuE.
        fixed_output = copy.deepcopy(output)
        for i, line in enumerate(output):
            for j, element in enumerate(line):
                if element < 0 or element > 1023:
                    counter = Counter(line)
                    most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
                    fixed_output[i, j] = most_frequant
        # save output
        np.save(output_filename, fixed_output)
        stage2_result.append(output_filename)
    return stage2_result
|
| 459 |
+
|
| 460 |
+
def process_one_song(music_data, song_idx, total_songs):
    """Process Stage 1 for a single song.

    Tokenizes the genre/lyrics prompt, autoregressively generates interleaved
    vocal/instrumental codec tokens segment by segment, and saves the two
    de-interleaved token tracks as .npy files.

    Uses module-level globals: `args`, `mmtokenizer`, `codectool`,
    `codec_model`, `model`, `device`, `max_new_tokens`, `do_sample_stage1`,
    `rearrange`.

    Args:
        music_data: Dict with 'lyrics' and 'genre' (or 'description') keys.
        song_idx: Global song index, used in the output directory name.
        total_songs: Total number of songs (informational only).

    Returns:
        (stage1_output_set, stage2_output_dir, song_output_dir), where
        stage1_output_set is [vocal_npy_path, instrumental_npy_path].
    """

    # Compatible with genre and description fields
    genres = music_data.get('genre') or music_data.get('description', '')
    lyrics_raw = music_data['lyrics']
    description = music_data.get('description', '')

    print(f"Description: {description[:100]}...")
    print(f"Genre tags: {genres}")

    # ===== Print original lyrics =====
    print("\n" + "="*60)
    print("【Original Lyrics (lyrics_raw)】")
    print("="*60)
    print(lyrics_raw)
    print("="*60 + "\n")

    lyrics = split_lyrics(lyrics_raw)

    # Validate lyrics format and give warnings (following official best practices)
    print(f"Lyrics analysis: Identified {len(lyrics)} segments")

    # ===== Print segmented lyrics =====
    print("\n" + "="*60)
    print("【Segmented Lyrics (lyrics)】")
    print("="*60)
    for i, seg in enumerate(lyrics):
        tag = seg.split('\n')[0].strip()
        # Check if unstable [intro] tag is used
        if 'intro' in tag.lower():
            print(f" ⚠️ Warning: Segment {i+1} uses {tag} tag, official recommendation is to avoid [intro], use [verse] or [chorus] instead")
        else:
            print(f" Segment {i+1}. {tag}")
        # Print each segment's content (limit length)
        content = seg.strip()
        if len(content) > 150:
            print(f" Content preview: {content[:150]}...")
        else:
            print(f" Content: {content}")
        print()
    print("="*60 + "\n")

    # Create output directory for this song
    random_id = uuid.uuid4()
    song_output_dir = os.path.join(args.output_dir, f"song_{song_idx:04d}_{random_id}")
    stage1_output_dir = os.path.join(song_output_dir, "stage1")
    stage2_output_dir = os.path.join(song_output_dir, "stage2")
    os.makedirs(stage1_output_dir, exist_ok=True)
    os.makedirs(stage2_output_dir, exist_ok=True)

    # Stage 1: Generate audio tokens
    print("--- Stage 1: Generate audio tokens ---")
    stage1_output_set = []
    full_lyrics = "\n".join(lyrics)
    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
    prompt_texts += lyrics

    # ===== Print prompt texts passed to model =====
    print("\n" + "="*60)
    print("【Prompt Texts Passed to Model (prompt_texts)】")
    print("="*60)
    print(f"Total {len(prompt_texts)} prompts (first is full prompt, subsequent are segments)\n")
    for i, pt in enumerate(prompt_texts):
        if i == 0:
            print(f"Prompt {i} [Full prompt header]:")
            if len(pt) > 300:
                print(f"{pt[:300]}...")
            else:
                print(pt)
        else:
            print(f"\nPrompt {i} [Segment {i}]:")
            if len(pt) > 200:
                print(f"{pt[:200]}...")
            else:
                print(pt)
    print("="*60 + "\n")

    output_seq = None
    # Here is suggested decoding config
    top_p = 0.93
    temperature = 1.0
    repetition_penalty = args.repetition_penalty
    if not do_sample_stage1:
        print("Note: --no_sample is enabled, Stage 1 will use deterministic decoding; top_p/temperature will be ignored.")
    # special tokens
    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
    # Format text prompt
    # +1 because prompt_texts[0] is the full prompt which will be skipped, so need len(lyrics)+1 to process all segments
    run_n_segments = min(args.run_n_segments+1, len(lyrics)+1)

    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference")):
        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
        # Stronger CFG on the first segment, lighter afterwards.
        guidance_scale = 1.5 if i <=1 else 1.2

        # ===== Print currently processing segment =====
        if i == 0:
            print(f"\n[Segment {i}] Skipped (full prompt header)")
        else:
            print(f"\n" + "-"*60)
            print(f"[Processing segment {i}/{len(prompt_texts[:run_n_segments])-1}]")
            print("-"*60)
            tag_line = section_text.split('\n')[0] if '\n' in section_text else section_text[:50]
            print(f"Segment tag: {tag_line}")
            print(f"Segment content length: {len(section_text)} characters")
            if len(section_text) > 200:
                print(f"Segment content preview: {section_text[:200]}...")
            else:
                print(f"Segment content: {section_text}")
            print("-"*60)

        if i==0:
            continue
        if i==1:
            # First real segment: build the full text head, optionally with a
            # reference-audio token span, before the segment tokens.
            if args.use_dual_tracks_prompt or args.use_audio_prompt:
                if args.use_dual_tracks_prompt:
                    vocals_ids = load_audio_mono(args.vocal_track_prompt_path)
                    instrumental_ids = load_audio_mono(args.instrumental_track_prompt_path)
                    vocals_ids = encode_audio(codec_model, vocals_ids, device, target_bw=0.5)
                    instrumental_ids = encode_audio(codec_model, instrumental_ids, device, target_bw=0.5)
                    vocals_ids = codectool.npy2ids(vocals_ids[0])
                    instrumental_ids = codectool.npy2ids(instrumental_ids[0])
                    # Interleave vocal/instrumental codes frame by frame.
                    ids_segment_interleaved = rearrange([np.array(vocals_ids), np.array(instrumental_ids)], 'b n -> (n b)')
                    # Factor 2 because two tracks are interleaved at 50 fps.
                    audio_prompt_codec = ids_segment_interleaved[int(args.prompt_start_time*50*2): int(args.prompt_end_time*50*2)]
                    audio_prompt_codec = audio_prompt_codec.tolist()
                elif args.use_audio_prompt:
                    audio_prompt = load_audio_mono(args.audio_prompt_path)
                    raw_codes = encode_audio(codec_model, audio_prompt, device, target_bw=0.5)
                    # Format audio prompt
                    code_ids = codectool.npy2ids(raw_codes[0])
                    audio_prompt_codec = code_ids[int(args.prompt_start_time *50): int(args.prompt_end_time *50)]  # 50 is tps of xcodec
                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
            else:
                head_id = mmtokenizer.tokenize(prompt_texts[0])
            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
        else:
            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids

        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
        # Use window slicing in case output sequence exceeds the context of model
        max_context = 16384-max_new_tokens-1
        if input_ids.shape[-1] > max_context:
            print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
            input_ids = input_ids[:, -(max_context):]
        with torch.no_grad():
            output_seq = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                min_new_tokens=100,
                do_sample=do_sample_stage1,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
                guidance_scale=guidance_scale,
            )
        # Ensure the segment terminates with an end-of-audio token.
        if output_seq[0][-1].item() != mmtokenizer.eoa:
            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
        if i > 1:
            # Append this segment's prompt plus only the newly generated
            # tokens (everything past the possibly-sliced input length).
            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
        else:
            raw_output = output_seq

    # save raw output and check sanity
    # NOTE(review): `raw_output` is only bound inside the loop above; if no
    # segment is processed (run_n_segments == 1) this raises NameError.
    ids = raw_output[0].cpu().numpy()
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
    if len(soa_idx)!=len(eoa_idx):
        raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')

    vocals = []
    instrumentals = []
    # Skip the reference-audio soa/eoa pair when an audio prompt was used.
    range_begin = 1 if args.use_audio_prompt or args.use_dual_tracks_prompt else 0
    for i in range(range_begin, len(soa_idx)):
        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
        # Drop a leading separator token (id 32016) if present.
        if codec_ids[0] == 32016:
            codec_ids = codec_ids[1:]
        # Trim to an even length so vocal/instrumental de-interleave cleanly.
        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
        vocals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[0])
        vocals.append(vocals_ids)
        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[1])
        instrumentals.append(instrumentals_ids)
    vocals = np.concatenate(vocals, axis=1)
    instrumentals = np.concatenate(instrumentals, axis=1)
    # Clean genres string to avoid filename being too long
    genres_clean = sanitize_genres_for_filename(genres, max_length=80)
    # '.'->'@' keeps a single extension dot in the final filename.
    vocal_save_path = os.path.join(stage1_output_dir, f"{genres_clean}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_{random_id}_vtrack".replace('.', '@')+'.npy')
    inst_save_path = os.path.join(stage1_output_dir, f"{genres_clean}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_{random_id}_itrack".replace('.', '@')+'.npy')
    np.save(vocal_save_path, vocals)
    np.save(inst_save_path, instrumentals)
    stage1_output_set.append(vocal_save_path)
    stage1_output_set.append(inst_save_path)

    return stage1_output_set, stage2_output_dir, song_output_dir
|
| 661 |
+
|
| 662 |
+
# Load Stage 2 model and vocoder (load only once)
print("\n" + "="*60)
print("Loading Stage 2 model...")
print("="*60)
model_stage2 = AutoModelForCausalLM.from_pretrained(
    stage2_model,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # Using flash_attention_2 for better performance
    # device_map="auto",
)
model_stage2.to(device)
model_stage2.eval()

# Opportunistically compile on PyTorch >= 2.0 (string compare is lexicographic).
if torch.__version__ >= "2.0.0":
    try:
        model_stage2 = torch.compile(model_stage2)
    except Exception as e:
        print(f"Warning: torch.compile not available: {e}")

print("Loading vocoder...")
vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)

# Batch process all songs - process each song completely before continuing to next
all_results = []
skipped_count = 0
for list_idx, music_data in enumerate(music_data_list):
    # Calculate actual song index (considering start_idx offset)
    song_idx = start_idx + list_idx

    try:
        # Compatible with genre and description fields
        genres = music_data.get('genre') or music_data.get('description', '')

        # Check processing status
        # (song_status_map is built earlier in the file for resume support.)
        stage1_done = False
        stage2_done = False
        stage3_done = False
        song_output_dir = None
        stage1_output_set = None
        stage2_output_dir = None

        if song_idx in song_status_map:
            stage1_done, stage2_done, stage3_done, song_output_dir, stage1_output_set, stage2_output_dir = song_status_map[song_idx]

        # If all completed, skip
        if stage3_done:
            print(f"\n{'='*60}")
            print(f"⏭️ Skipping song {list_idx+1}/{len(music_data_list)} (index {song_idx}, fully completed)")
            print(f"{'='*60}")
            skipped_count += 1
            continue

        # Decide which stage to start from based on completion status
        print(f"\n{'='*60}")
        print(f"Starting to process song {list_idx+1}/{len(music_data_list)} (index {song_idx})")
        if stage1_done:
            print(f" ✓ Stage 1 completed, will start from Stage 2")
        if stage2_done:
            print(f" ✓ Stage 2 completed, will start from Stage 3")
        print(f"{'='*60}")

        # Stage 1: Generate audio tokens (if not completed)
        if not stage1_done:
            stage1_output_set, stage2_output_dir, song_output_dir = process_one_song(music_data, song_idx, len(music_data_list))
            # NOTE(review): the leading characters below are mojibake in the
            # original source (likely a corrupted checkmark emoji); preserved
            # verbatim here since this is a runtime string.
            print(f"��� Stage 1 completed, generated {len(stage1_output_set)} files")
            for f in stage1_output_set:
                print(f" - {os.path.basename(f)}")
        else:
            print(f"⏭️ Skipping Stage 1 (completed)")
            print(f" Using existing Stage 1 outputs:")
            for f in stage1_output_set:
                print(f" - {os.path.basename(f)}")

        # Note: Do not unload Stage 1 model here, as subsequent songs still need it
        # Stage 1 model will be unloaded uniformly after all songs are processed

        # Stage 2: Process audio tokens (if not completed)
        if not stage2_done:
            print(f"\n--- Stage 2: Processing song {list_idx+1} (index {song_idx}) ---")
            stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=args.stage2_batch_size)
            print(f"✓ Stage 2 completed, generated {len(stage2_result)} files")
            for f in stage2_result:
                print(f" - {os.path.basename(f)}")
        else:
            print(f"\n⏭️ Skipping Stage 2 (completed)")
            # Get existing stage2 results
            stage2_result = []
            if os.path.exists(stage2_output_dir):
                for f in stage1_output_set:
                    basename = os.path.basename(f)
                    stage2_file = os.path.join(stage2_output_dir, basename)
                    if os.path.exists(stage2_file):
                        stage2_result.append(stage2_file)
            print(f" Using existing Stage 2 outputs:")
            for f in stage2_result:
                print(f" - {os.path.basename(f)}")

        # Stage 3: Reconstruct audio and mix (if not completed)
        final_output = None
        if not stage3_done:
            print(f"\n--- Stage 3: Reconstructing audio for song {list_idx+1} (index {song_idx}) ---")

            # reconstruct tracks
            recons_output_dir = os.path.join(song_output_dir, "recons")
            recons_mix_dir = os.path.join(recons_output_dir, 'mix')
            os.makedirs(recons_mix_dir, exist_ok=True)
            tracks = []
            for npy in stage2_result:
                codec_result = np.load(npy)
                decodec_rlt=[]
                with torch.no_grad():
                    # Decode the 8-codebook token grid back to a 16 kHz waveform.
                    decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
                    decoded_waveform = decoded_waveform.cpu().squeeze(0)
                decodec_rlt.append(torch.as_tensor(decoded_waveform))
                decodec_rlt = torch.cat(decodec_rlt, dim=-1)
                save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
                tracks.append(save_path)
                save_audio(decodec_rlt, save_path, 16000)

            # mix tracks
            recons_mix = None
            for inst_path in tracks:
                try:
                    if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
                        and '_itrack' in inst_path:
                        # find pair
                        vocal_path = inst_path.replace('_itrack', '_vtrack')
                        if not os.path.exists(vocal_path):
                            continue
                        # mix
                        recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('_itrack', '_mixed'))
                        # NOTE(review): variable names are swapped here —
                        # `vocal_stem` is read from the instrumental file and
                        # vice versa. The sum below is unaffected, but the
                        # names are misleading.
                        vocal_stem, sr = sf.read(inst_path)
                        instrumental_stem, _ = sf.read(vocal_path)
                        # `/ 1` is a no-op; a real mixdown might divide by 2.
                        mix_stem = (vocal_stem + instrumental_stem) / 1
                        sf.write(recons_mix, mix_stem, sr)
                except Exception as e:
                    print(e)

            # vocoder to upsample audios
            vocoder_output_dir = os.path.join(song_output_dir, 'vocoder')
            vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
            vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
            os.makedirs(vocoder_mix_dir, exist_ok=True)
            os.makedirs(vocoder_stems_dir, exist_ok=True)

            for npy in stage2_result:
                if '_itrack' in npy:
                    # Process instrumental
                    instrumental_output = process_audio(
                        npy,
                        os.path.join(vocoder_stems_dir, 'itrack.mp3'),
                        args.rescale,
                        args,
                        inst_decoder,
                        codec_model
                    )
                else:
                    # Process vocal
                    vocal_output = process_audio(
                        npy,
                        os.path.join(vocoder_stems_dir, 'vtrack.mp3'),
                        args.rescale,
                        args,
                        vocal_decoder,
                        codec_model
                    )

            # mix tracks
            # NOTE(review): if either track type is missing from
            # stage2_result, the sum below raises NameError, which the
            # `except RuntimeError` does not catch.
            vocoder_mix = None
            try:
                mix_output = instrumental_output + vocal_output
                vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
                save_audio(mix_output, vocoder_mix, 44100, args.rescale)
                print(f"Created mix: {vocoder_mix}")
            except RuntimeError as e:
                print(e)
                print(f"Mix failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")

            # Post process
            # Blend the codec mix's low band with the vocoder mix's high band.
            if recons_mix and vocoder_mix:
                final_output = os.path.join(song_output_dir, os.path.basename(recons_mix))
                replace_low_freq_with_energy_matched(
                    a_file=recons_mix,  # 16kHz
                    b_file=vocoder_mix,  # 48kHz
                    c_file=final_output,
                    cutoff_freq=5500.0
                )
                print(f"✓ Song {list_idx+1} (index {song_idx}) completed! Output: {final_output}")
        else:
            print(f"\n⏭️ Skipping Stage 3 (completed)")
            # Find final output file (usually in song_dir root directory)
            # First check root directory
            root_files = [f for f in os.listdir(song_output_dir) if f.endswith('_mixed.mp3')]
            if root_files:
                final_output = os.path.join(song_output_dir, root_files[0])
            else:
                # If root directory doesn't have it, traverse subdirectories to find
                for root, dirs, files in os.walk(song_output_dir):
                    for f in files:
                        if f.endswith('_mixed.mp3'):
                            final_output = os.path.join(root, f)
                            break
                    if final_output:
                        break
            if final_output:
                print(f" Final output: {final_output}")

        # NOTE(review): when Stage 3 was skipped above, `recons_mix` and
        # `vocoder_mix` are unbound here, so this expression raises NameError
        # which the broad except below swallows — the song then never lands
        # in all_results. Worth fixing upstream.
        all_results.append({
            'song_idx': song_idx,
            'genres': genres,
            'output_path': final_output if recons_mix and vocoder_mix else None
        })

    except Exception as e:
        print(f"✗ Error processing song {list_idx+1} (index {song_idx}): {e}")
        import traceback
        traceback.print_exc()
        continue

# After all songs are processed, unload models to free memory
if not args.disable_offload_model:
    print("\nCleaning up models to free memory...")
    if 'model' in locals():
        model.cpu()
        del model
    if 'model_stage2' in locals():
        model_stage2.cpu()
        del model_stage2
    torch.cuda.empty_cache()
    print("Models unloaded")

# --- Final summary ---
print("\n" + "="*60)
print("Batch generation complete!")
newly_processed = len([r for r in all_results if r.get('output_path')])
print(f"✓ Newly processed: {newly_processed} songs")
if skipped_count > 0:
    print(f"⏭️ Skipped (already completed): {skipped_count} songs")
print(f"📊 Total completed: {newly_processed + skipped_count} songs")
print("="*60)
for result in all_results:
    if result.get('output_path'):
        print(f"Song {result['song_idx']+1}: {result['output_path']}")
|
| 904 |
+
|
baseline_generate/yue/mmtokenizer.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AbstractTokenizer(ABC):
    """Base interface every tokenizer in this module implements.

    Subclasses must provide the vocabulary accessors and ``tokenize``;
    the remaining hooks (``detokenize`` plus the special-token id
    properties) raise ``NotImplementedError`` unless overridden.
    """

    def __init__(self, name):
        # Human-readable tokenizer name, used in the error messages below.
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        """Total number of tokens in the vocabulary."""

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""

    @abstractmethod
    def tokenize(self, text):
        """Convert ``text`` into a list of token ids."""

    def detokenize(self, token_ids):
        # Optional inverse of tokenize(); concrete tokenizers may override.
        raise NotImplementedError(
            f'detokenizer is not implemented for {self.name} tokenizer')

    @property
    def cls(self):
        raise NotImplementedError(
            f'CLS is not provided for {self.name} tokenizer')

    @property
    def sep(self):
        raise NotImplementedError(
            f'SEP is not provided for {self.name} tokenizer')

    @property
    def pad(self):
        raise NotImplementedError(
            f'PAD is not provided for {self.name} tokenizer')

    @property
    def eod(self):
        raise NotImplementedError(
            f'EOD is not provided for {self.name} tokenizer')

    @property
    def mask(self):
        raise NotImplementedError(
            f'MASK is not provided for {self.name} tokenizer')
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class _SentencePieceTokenizer(AbstractTokenizer):
    """SentencePieceTokenizer-Megatron wrapper"""

    def __init__(self, model_file, vocab_extra_ids=0):
        # model_file: path to a trained SentencePiece .model file.
        # vocab_extra_ids: number of T5-style "<extra_id_N>" sentinel tokens
        # to append to the vocabulary; 0 disables them.
        name = 'SentencePieceTokenizer'
        super().__init__(name)

        import sentencepiece
        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
        self._initalize(vocab_extra_ids)

    def _populate_vocab(self):
        # Build the forward (piece -> id) and inverse (id -> piece) maps
        # straight from the underlying SentencePiece model.
        self._vocab = {}
        self._inv_vocab = {}

        for i in range(len(self.tokenizer)):
            t = self.tokenizer.id_to_piece(i)
            self._inv_vocab[i] = t
            self._vocab[t] = i

    def _initalize(self, vocab_extra_ids):
        # NOTE: the historical misspelling ("initalize") is kept on purpose;
        # subclasses override this exact method name as an extension hook.
        self._populate_vocab()
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        def _add_special_token(t):
            # Register t as a special token, appending it to the vocabulary
            # when the SentencePiece model does not already contain it.  The
            # call order below therefore fixes the ids of new tokens.
            if t not in self._vocab:
                next_id = len(self._vocab)
                self._vocab[t] = next_id
                self._inv_vocab[next_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t

        _add_special_token('<CLS>')
        self._cls_id = self._vocab['<CLS>']
        _add_special_token('<SEP>')
        self._sep_id = self._vocab['<SEP>']
        _add_special_token('<EOD>')
        self._eod_id = self._vocab['<EOD>']
        _add_special_token('<MASK>')
        self._mask_id = self._vocab['<MASK>']

        # For PAD/BOS/EOS, prefer the piece the SentencePiece model defines;
        # id_to_piece raises IndexError when the model lacks that id, in
        # which case a fallback literal is registered instead.
        pad_id = self.tokenizer.pad_id()
        try:
            pad_token = self.tokenizer.id_to_piece(pad_id)
        except IndexError:
            pad_token = '<PAD>'
        _add_special_token(pad_token)
        self._pad_id = self._vocab[pad_token]

        bos_id = self.tokenizer.bos_id()
        try:
            bos_token = self.tokenizer.id_to_piece(bos_id)
        except IndexError:
            bos_token = '<BOS>'
        _add_special_token(bos_token)
        self._bos_id = self._vocab[bos_token]

        eos_id = self.tokenizer.eos_id()
        try:
            eos_token = self.tokenizer.id_to_piece(eos_id)
        except IndexError:
            eos_token = '<EOS>'
        _add_special_token(eos_token)
        self._eos_id = self._vocab[eos_token]

        # T5-style sentinels: <extra_id_0>, <extra_id_1>, ...
        for i in range(vocab_extra_ids):
            t = "<extra_id_{}>".format(i)
            _add_special_token(t)
            self._t5_tokens += [t]

    @property
    def vocab_size(self):
        # Includes any special tokens appended after the base model vocab.
        return len(self._vocab)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def decoder(self):
        # Alias kept for HuggingFace-style callers.
        return self._inv_vocab

    @property
    def encoder(self):
        # Alias kept for HuggingFace-style callers.
        return self._vocab

    # From:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
    def tokenize(self, text):
        """Encode text to ids, emitting special-token ids verbatim."""
        ids = []
        idx = 0

        while 1:
            # Locate the next occurrence of every special token after idx.
            indices = {}
            for token in self._special_tokens:
                try:
                    indices[token] = text[idx:].index(token)
                except ValueError:
                    continue
            if len(indices) == 0:
                break

            # Encode the plain text up to the earliest special token, append
            # that token's id, then continue scanning after it.
            next_token = min(indices, key=indices.get)
            next_idx = idx + indices[next_token]

            ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
            ids.append(self._special_tokens[next_token])
            idx = next_idx + len(next_token)

        ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
        return ids

    # From:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
    def detokenize(self, ids):
        """Decode ids to text; special tokens come out as their literal
        piece, separated from neighbours by single spaces."""
        text = ""
        last_i = 0

        for i, id in enumerate(ids):
            if id in self._inv_special_tokens:
                text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
                text += self._inv_special_tokens[id] + " "
                last_i = i + 1

        text += self.tokenizer.decode_ids(ids[last_i:])
        return text

    @property
    def cls(self):
        return self._cls_id

    @property
    def sep(self):
        return self._sep_id

    @property
    def pad(self):
        return self._pad_id

    @property
    def bos_token_id(self):
        return self._bos_id

    @property
    def bos(self):
        return self._bos_id

    @property
    def eod(self):
        return self._eod_id

    @property
    def eos_token_id(self):
        return self._eos_id

    @property
    def eos(self):
        return self._eos_id

    @property
    def mask(self):
        return self._mask_id

    @property
    def additional_special_tokens_ids(self):
        # Ids of the "<extra_id_N>" sentinels, in registration order.
        return [self.vocab[k] for k in self._t5_tokens]
|
| 236 |
+
|
| 237 |
+
class _MMSentencePieceTokenizer(_SentencePieceTokenizer):
    """SentencePieceTokenizer-Megatron wrapper"""
    # Multimodal variant: extends the parent's special-token set with audio
    # (<SOA>/<EOA>), vocal (<SOV>/<EOV>), instrumental (<SOI>/<EOI>),
    # local/global context, and two generation-stage markers.

    def __init__(self, model_file, vocab_extra_ids=0):
        super().__init__(model_file, vocab_extra_ids)


    def _initalize(self, vocab_extra_ids):
        # Overrides the parent hook (deliberately keeping its misspelled
        # name) to register the extra multimodal tokens; the registration
        # order below fixes the ids of any newly appended tokens.
        self._populate_vocab()
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        def _add_special_token(t):
            # Append t to the vocab if missing, then record it as special.
            if t not in self._vocab:
                next_id = len(self._vocab)
                self._vocab[t] = next_id
                self._inv_vocab[next_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t

        _add_special_token('<CLS>')
        self._cls_id = self._vocab['<CLS>']
        _add_special_token('<SEP>')
        self._sep_id = self._vocab['<SEP>']
        _add_special_token('<EOD>')
        self._eod_id = self._vocab['<EOD>']
        _add_special_token('<MASK>')
        self._mask_id = self._vocab['<MASK>']

        # Audio / vocal / instrumental span delimiters.
        _add_special_token('<SOA>')
        self._soa_id = self._vocab['<SOA>']
        _add_special_token('<EOA>')
        self._eoa_id = self._vocab['<EOA>']
        _add_special_token('<SOV>')
        self._sov_id = self._vocab['<SOV>']
        _add_special_token('<EOV>')
        self._eov_id = self._vocab['<EOV>']
        _add_special_token('<SOI>')
        self._soi_id = self._vocab['<SOI>']
        _add_special_token('<EOI>')
        self._eoi_id = self._vocab['<EOI>']
        # Local / global context delimiters.
        _add_special_token('<s_local>')
        self._s_local_id = self._vocab['<s_local>']
        _add_special_token('<e_local>')
        self._e_local_id = self._vocab['<e_local>']
        _add_special_token('<s_global>')
        self._s_global_id = self._vocab['<s_global>']
        _add_special_token('<e_global>')
        self._e_global_id = self._vocab['<e_global>']
        # Markers for the two-stage generation pipeline.
        _add_special_token('<stage_1>')
        self._stage_1_id = self._vocab['<stage_1>']
        _add_special_token('<stage_2>')
        self._stage_2_id = self._vocab['<stage_2>']
        # PAD/BOS/EOS handling mirrors the parent class: reuse the model's
        # piece when available, otherwise register a fallback literal.
        pad_id = self.tokenizer.pad_id()
        try:
            pad_token = self.tokenizer.id_to_piece(pad_id)
        except IndexError:
            pad_token = '<PAD>'
        _add_special_token(pad_token)
        self._pad_id = self._vocab[pad_token]

        bos_id = self.tokenizer.bos_id()
        try:
            bos_token = self.tokenizer.id_to_piece(bos_id)
        except IndexError:
            bos_token = '<BOS>'
        _add_special_token(bos_token)
        self._bos_id = self._vocab[bos_token]

        eos_id = self.tokenizer.eos_id()
        try:
            eos_token = self.tokenizer.id_to_piece(eos_id)
        except IndexError:
            eos_token = '<EOS>'
        _add_special_token(eos_token)
        self._eos_id = self._vocab[eos_token]

        # T5-style sentinels: <extra_id_0>, <extra_id_1>, ...
        for i in range(vocab_extra_ids):
            t = "<extra_id_{}>".format(i)
            _add_special_token(t)
            self._t5_tokens += [t]

    @property
    def soa(self):
        return self._soa_id

    @property
    def eoa(self):
        return self._eoa_id

    @property
    def sov(self):
        return self._sov_id

    @property
    def eov(self):
        return self._eov_id

    @property
    def soi(self):
        return self._soi_id

    @property
    def eoi(self):
        return self._eoi_id

    @property
    def s_local(self):
        return self._s_local_id

    @property
    def e_local(self):
        return self._e_local_id

    @property
    def s_global(self):
        return self._s_global_id

    @property
    def e_global(self):
        return self._e_global_id

    @property
    def stage_1(self):
        return self._stage_1_id

    @property
    def stage_2(self):
        return self._stage_2_id
|
data_pipeline/lyrics_gene/__pycache__/filter_all_cn.cpython-311.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
data_pipeline/lyrics_gene/__pycache__/filter_all_en.cpython-311.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
data_pipeline/lyrics_gene/__pycache__/gen_lyrics_cn.cpython-311.pyc
ADDED
|
Binary file (31 kB). View file
|
|
|
data_pipeline/lyrics_gene/filter_all_cn.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import torch
|
| 5 |
+
from sentence_transformers import SentenceTransformer, util
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
|
| 9 |
+
# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
|
| 10 |
+
|
| 11 |
+
# Use HuggingFace mirror site
|
| 12 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 13 |
+
# Set HuggingFace cache directory so SentenceTransformer can recognize downloaded models
|
| 14 |
+
os.environ["HF_HOME"] = os.path.expanduser("~/.cache/huggingface")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def clean_lyrics(lyrics):
    """Strip bracketed tags and line breaks from raw lyrics.

    Removes every ``[...]`` span (section headers like ``[Verse 1]`` and
    timestamps like ``[00:07.00]``), folds newlines into spaces, collapses
    whitespace runs, and trims the ends.

    Args:
        lyrics: Raw lyrics text, possibly containing tags and newlines.

    Returns:
        The cleaned, single-line lyric text.
    """
    # Drop every bracketed span (non-greedy so adjacent tags stay separate).
    text = re.sub(r'\[.*?\]', '', lyrics)
    # Newlines become spaces, then any whitespace run becomes one space.
    text = re.sub(r'\s+', ' ', text.replace('\n', ' '))
    return text.strip()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_music_data(input_file, max_count=None):
    """Read song records from a jsonl file.

    Only lines that parse as JSON objects carrying both ``description``
    and ``lyrics`` keys are kept; malformed lines are skipped silently.

    Args:
        input_file: Path to the input jsonl file (one JSON object per line).
        max_count: Optional cap on the number of records to keep; None
            reads the whole file.

    Returns:
        List of song dicts, in file order.
    """
    songs = []
    print(f"Loading music data: {input_file}")
    if max_count:
        print(f"Limiting to first {max_count} songs")

    with open(input_file, 'r', encoding='utf-8') as fh:
        for raw_line in tqdm(fh, desc="Loading data"):
            try:
                record = json.loads(raw_line.strip())
            except json.JSONDecodeError:
                # Tolerate corrupt lines rather than aborting the run.
                continue
            if 'description' in record and 'lyrics' in record:
                songs.append(record)
                if max_count and len(songs) >= max_count:
                    break
    print(f"Successfully loaded {len(songs)} songs")
    return songs
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def deduplicate_music(music_list, texts, model, threshold=0.90, output_file=None, save_interval=10000, matrix_save_dir=None):
    """Deduplicate music data based on pairwise text similarity.

    A song is dropped when its embedding's cosine similarity to an earlier
    kept song exceeds ``threshold``.

    Args:
        music_list: List of music data dicts.
        texts: List of comparison texts, parallel to ``music_list``.
        model: SentenceTransformer model used to embed ``texts``.
        threshold: Cosine-similarity threshold above which a later song is
            considered a duplicate.
        output_file: Optional output path; when given, kept songs are
            written incrementally as jsonl.
        save_interval: Flush kept songs to disk every N kept songs.
        matrix_save_dir: Optional directory; when given, the embeddings and
            the full similarity matrix are saved there as .pt files.

    Returns:
        Deduplicated music data list.
    """
    print(f"Computing embeddings for {len(texts)} texts...")
    embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)

    print("Computing similarity matrix...")
    cos_scores = util.pytorch_cos_sim(embeddings, embeddings)

    # Optionally persist embeddings and the similarity matrix for reuse.
    if matrix_save_dir:
        os.makedirs(matrix_save_dir, exist_ok=True)
        embeddings_path = os.path.join(matrix_save_dir, 'embeddings.pt')
        cos_scores_path = os.path.join(matrix_save_dir, 'cos_scores.pt')
        print(f"Saving embeddings to: {embeddings_path}")
        torch.save(embeddings.cpu(), embeddings_path)
        print(f"Saving similarity matrix to: {cos_scores_path}")
        torch.save(cos_scores.cpu(), cos_scores_path)
        print("Matrix saving complete!")

    print(f"Deduplicating (threshold: {threshold})...")
    keep_idx = []
    removed = set()

    # If an output file is provided, open it for incremental saving.
    f = None
    if output_file:
        # BUG FIX: dirname is '' for a bare filename, and os.makedirs('')
        # raises FileNotFoundError — only create a directory when there is one.
        out_dir = os.path.dirname(output_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        f = open(output_file, 'w', encoding='utf-8')

    saved_count = 0

    # try/finally guarantees the output handle is closed even if the
    # deduplication loop raises (the original leaked it on error).
    try:
        for i in tqdm(range(len(music_list)), desc="Deduplication progress"):
            if i in removed:
                continue
            keep_idx.append(i)

            # Incremental save: flush every save_interval kept songs.
            if f and len(keep_idx) - saved_count >= save_interval:
                for idx in range(saved_count, len(keep_idx)):
                    music = music_list[keep_idx[idx]]
                    f.write(json.dumps(music, ensure_ascii=False) + '\n')
                f.flush()  # Ensure write to disk
                saved_count = len(keep_idx)
                print(f"Saved {saved_count} valid songs to file")

            # Mark every later song too similar to song i as a duplicate.
            for j in range(i + 1, len(music_list)):
                if cos_scores[i][j] > threshold:
                    removed.add(j)

        # Save any kept songs not yet flushed.
        if f:
            for idx in range(saved_count, len(keep_idx)):
                music = music_list[keep_idx[idx]]
                f.write(json.dumps(music, ensure_ascii=False) + '\n')
            print(f"Saved all {len(keep_idx)} valid songs to file")
    finally:
        if f:
            f.close()

    deduped_music_list = [music_list[i] for i in keep_idx]
    print(f"Deduplication complete: {len(music_list)} -> {len(deduped_music_list)} (removed {len(removed)} songs)")

    return deduped_music_list
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def dedup_by_description_and_lyrics(input_file, output_file, threshold=0.95, max_count=None, device='cuda:1', save_interval=10000, matrix_save_dir=None):
    """Method 2: Deduplicate based on description + lyrics.

    Embeds "description [SEP] cleaned-lyrics" per song with a Chinese
    text2vec model and drops near-duplicates via ``deduplicate_music``.

    Args:
        input_file: Path to input jsonl file.
        output_file: Path to output jsonl file; when falsy, results are not
            written to disk.
        threshold: Similarity threshold.
        max_count: Maximum number to read, None means read all.
        device: Device to use, default cuda:1 (GPU1).
        save_interval: Save every N valid songs processed, default 10000.
        matrix_save_dir: Directory to save matrices, if provided saves
            embeddings and similarity matrix.

    Returns:
        Deduplicated music data list.
    """
    print("\n========== Method 2: Deduplicate based on description + lyrics ==========")
    print(f"Using device: {device}")

    # Load data
    music_list = load_music_data(input_file, max_count=max_count)

    # Build "description [SEP] cleaned lyrics" comparison texts.
    combined_texts = []
    for music in music_list:
        description = music.get('description', '')
        cleaned_lyrics = clean_lyrics(music.get('lyrics', ''))
        combined_texts.append(f"{description} [SEP] {cleaned_lyrics}")

    # Prefer a locally cached copy of the model to avoid re-downloading.
    hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    model_cache_dir = os.path.join(hf_home, "hub", "models--shibing624--text2vec-bge-large-chinese", "snapshots")

    model = None
    if os.path.exists(model_cache_dir):
        snapshots = [d for d in os.listdir(model_cache_dir) if os.path.isdir(os.path.join(model_cache_dir, d))]
        if snapshots:
            # Use latest snapshot (usually unique)
            local_model_path = os.path.join(model_cache_dir, snapshots[0])
            if os.path.exists(os.path.join(local_model_path, "config.json")):
                print(f"Detected local model, using path: {local_model_path}")
                model = SentenceTransformer(local_model_path, device=device)

    if model is None:
        print(f"Loading model to {device}...")
        model = SentenceTransformer('shibing624/text2vec-bge-large-chinese', device=device)

    # Deduplicate; deduplicate_music handles incremental saving itself.
    deduped_music_list = deduplicate_music(
        music_list,
        combined_texts,
        model,
        threshold,
        output_file=output_file,
        save_interval=save_interval,
        matrix_save_dir=matrix_save_dir
    )

    if output_file:
        print(f"✓ Save complete! Remaining {len(deduped_music_list)} songs after deduplication\n")
    # BUG FIX: the old else-branch tried to re-save using output_file even
    # though it is falsy here (os.path.dirname(None) raises TypeError and
    # open('') fails); with no output path there is nothing to write, so the
    # deduplicated list is simply returned.

    return deduped_music_list
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
if __name__ == '__main__':
    # Run configuration for the description+lyrics deduplication pass.
    input_file = 'lrc_4w_single_pro_des.jsonl'   # source jsonl
    output_file = 'filter_all_4w.jsonl'          # deduplicated output
    matrix_save_dir = 'generate_lrc'             # embeddings / cos matrix dir
    max_count = None                             # None -> process everything

    print("\nDeduplicating based on description + lyrics")
    dedup_by_description_and_lyrics(
        input_file,
        output_file,
        threshold=0.90,
        max_count=max_count,
        device='cuda:7',
        save_interval=10000,   # flush every 10000 kept songs
        matrix_save_dir=matrix_save_dir,
    )
    print(f"\nComplete! Results saved to: {output_file}")
    print(f"Similarity matrix saved to: {matrix_save_dir}")
|
| 272 |
+
|
data_pipeline/lyrics_gene/filter_all_en.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import torch
|
| 5 |
+
from sentence_transformers import SentenceTransformer, util
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
|
| 9 |
+
# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
|
| 10 |
+
|
| 11 |
+
# Use HuggingFace mirror site
|
| 12 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 13 |
+
# Set HuggingFace cache directory so SentenceTransformer can recognize downloaded models
|
| 14 |
+
os.environ["HF_HOME"] = os.path.expanduser("~/.cache/huggingface")
|
| 15 |
+
|
| 16 |
+
MODEL_EN = "sentence-transformers/all-MiniLM-L6-v2"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def clean_lyrics(lyrics):
    """Return plain lyric text with every bracketed tag and line break gone.

    Bracketed spans such as section headers (``[Verse 1]``) and timestamps
    (``[00:07.00]``) are removed, whitespace runs are collapsed to single
    spaces, and leading/trailing spaces are trimmed.

    Args:
        lyrics: Raw lyrics text.

    Returns:
        Cleaned single-line lyrics.
    """
    without_tags = re.sub(r'\[.*?\]', '', lyrics)
    # Newlines first become spaces so the collapse pass sees one space kind.
    one_line = without_tags.replace('\n', ' ')
    normalized = re.sub(r'\s+', ' ', one_line)
    return normalized.strip()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_music_data(input_file, max_count=None):
    """
    Read music records from a JSONL file.

    Args:
        input_file: Path to the input .jsonl file.
        max_count: Optional cap on the number of records; None reads all.

    Returns:
        A list of dicts, each guaranteed to contain both a
        'description' and a 'lyrics' field.
    """
    print(f"Loading music data: {input_file}")
    if max_count:
        print(f"Limiting to first {max_count} songs")

    records = []
    with open(input_file, 'r', encoding='utf-8') as fh:
        for raw_line in tqdm(fh, desc="Loading data"):
            try:
                item = json.loads(raw_line.strip())
                # Keep only records carrying the fields downstream code needs
                if 'description' in item and 'lyrics' in item:
                    records.append(item)
                    # Stop early once the requested cap is reached
                    if max_count and len(records) >= max_count:
                        break
            except json.JSONDecodeError:
                # Best-effort loading: silently skip malformed lines
                continue
    print(f"Successfully loaded {len(records)} songs")
    return records
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def deduplicate_music(music_list, texts, model, threshold=0.90, output_file=None, save_interval=10000, matrix_save_dir=None):
    """
    Greedily deduplicate music records by pairwise text similarity.

    Encodes every text, builds the full NxN cosine-similarity matrix, then
    keeps the first record of each near-duplicate cluster: any later record
    whose similarity to an already-kept record exceeds `threshold` is dropped.

    Args:
        music_list: List of music data dicts.
        texts: Texts to compare, aligned 1:1 with music_list.
        model: SentenceTransformer model used for encoding.
        threshold: Cosine-similarity threshold above which two records
            count as duplicates.
        output_file: Optional output path; when given, kept records are
            written incrementally as JSONL.
        save_interval: Flush kept records to disk every N kept songs.
        matrix_save_dir: Optional directory to dump the embeddings and the
            similarity matrix for later inspection.

    Returns:
        Deduplicated list of music data dicts.
    """
    print(f"Computing embeddings for {len(texts)} texts...")
    embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)

    # NOTE: this materializes the full NxN similarity matrix; memory grows
    # quadratically with corpus size.
    print("Computing similarity matrix...")
    cos_scores = util.pytorch_cos_sim(embeddings, embeddings)

    # Persist embeddings / similarity matrix for offline analysis
    if matrix_save_dir:
        os.makedirs(matrix_save_dir, exist_ok=True)
        embeddings_path = os.path.join(matrix_save_dir, 'embeddings.pt')
        cos_scores_path = os.path.join(matrix_save_dir, 'cos_scores.pt')
        print(f"Saving embeddings to: {embeddings_path}")
        torch.save(embeddings.cpu(), embeddings_path)
        print(f"Saving similarity matrix to: {cos_scores_path}")
        torch.save(cos_scores.cpu(), cos_scores_path)
        print("Matrix saving complete!")

    print(f"Deduplicating (threshold: {threshold})...")
    keep_idx = []
    removed = set()

    # If an output file is provided, open it for incremental saving
    f = None
    if output_file:
        # BUGFIX: os.path.dirname() returns '' for a bare filename and
        # os.makedirs('') raises FileNotFoundError (the __main__ driver
        # passes exactly such a path). Only create the directory when the
        # path actually contains one.
        out_dir = os.path.dirname(output_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        f = open(output_file, 'w', encoding='utf-8')

    saved_count = 0

    for i in tqdm(range(len(music_list)), desc="Deduplication progress"):
        if i in removed:
            continue
        keep_idx.append(i)

        # Incremental save: flush newly kept songs every save_interval
        if f and len(keep_idx) - saved_count >= save_interval:
            for idx in range(saved_count, len(keep_idx)):
                music = music_list[keep_idx[idx]]
                f.write(json.dumps(music, ensure_ascii=False) + '\n')
            f.flush()  # Ensure write to disk
            saved_count = len(keep_idx)
            print(f"Saved {saved_count} valid songs to file")

        # Mark every later record that is too similar to this kept one
        for j in range(i + 1, len(music_list)):
            if cos_scores[i][j] > threshold:
                removed.add(j)

    # Save remaining valid songs and close the file
    if f:
        for idx in range(saved_count, len(keep_idx)):
            music = music_list[keep_idx[idx]]
            f.write(json.dumps(music, ensure_ascii=False) + '\n')
        f.close()
        print(f"Saved all {len(keep_idx)} valid songs to file")

    deduped_music_list = [music_list[i] for i in keep_idx]
    print(f"Deduplication complete: {len(music_list)} -> {len(deduped_music_list)} (removed {len(removed)} songs)")

    return deduped_music_list
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def dedup_by_description_and_lyrics(input_file, output_file, threshold=0.95, max_count=None, device='cuda:1', save_interval=10000, matrix_save_dir=None):
    """
    Method 2: Deduplicate songs by combined description + lyrics similarity.

    Args:
        input_file: Path to input jsonl file.
        output_file: Path to output jsonl file (saving is handled
            incrementally by deduplicate_music when this is set).
        threshold: Similarity threshold.
        max_count: Maximum number of records to read, None means read all.
        device: Device for the embedding model, e.g. 'cuda:1' or 'cpu'.
        save_interval: Flush kept records every N songs.
        matrix_save_dir: Optional directory for embeddings / similarity matrix.

    Returns:
        Deduplicated list of music data dicts.
    """
    print("\n========== Method 2: Deduplicate based on description + lyrics ==========")
    print(f"Using device: {device}")

    # Load data
    music_list = load_music_data(input_file, max_count=max_count)

    # Build one comparison string per song: description [SEP] cleaned lyrics
    combined_texts = []
    for music in music_list:
        description = music.get('description', '')
        lyrics = music.get('lyrics', '')
        # Clean lyrics, removing structure tags and timestamps
        cleaned_lyrics = clean_lyrics(lyrics)
        combined_text = f"{description} [SEP] {cleaned_lyrics}"
        combined_texts.append(combined_text)

    # Load English embedding model
    print(f"Loading English model {MODEL_EN} to {device}...")
    model = SentenceTransformer(MODEL_EN, device=device)

    # Deduplicate (supports incremental saving when output_file is set)
    deduped_music_list = deduplicate_music(
        music_list,
        combined_texts,
        model,
        threshold,
        output_file=output_file,
        save_interval=save_interval,
        matrix_save_dir=matrix_save_dir
    )

    if output_file:
        # deduplicate_music already wrote the results incrementally
        print(f"✓ Save complete! Remaining {len(deduped_music_list)} songs after deduplication\n")
    else:
        # BUGFIX: the old fallback tried to open(output_file) inside this
        # branch, which only runs when output_file is falsy and therefore
        # always crashed. There is nowhere to save, so just report.
        print(f"✓ Deduplication complete (no output file given). Remaining {len(deduped_music_list)} songs\n")

    return deduped_music_list
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
if __name__ == '__main__':
    # Input file path (JSONL with 'description' and 'lyrics' per line)
    input_file = 'en_lrc_4w_single_pro_des.jsonl'

    # Output file path (written incrementally during deduplication)
    output_file = 'filter_en_single_4w(0.9).jsonl'

    # Directory where embeddings and the similarity matrix are dumped
    matrix_save_dir = 'en_matrix'

    # Maximum number of records to read; None means read all
    max_count = None

    # Deduplicate based on description + lyrics
    print("\nDeduplicating based on description + lyrics")
    dedup_by_description_and_lyrics(
        input_file,
        output_file,
        threshold=0.90,
        max_count=max_count,
        device='cuda:7',  # assumes a machine with >= 8 GPUs -- adjust as needed
        save_interval=10000,  # Save every 10000 valid songs
        matrix_save_dir=matrix_save_dir  # Save similarity matrix
    )
    print(f"\nComplete! Results saved to: {output_file}")
    print(f"Similarity matrix saved to: {matrix_save_dir}")
    # Test lyrics cleaning effect (manual inspection helper, kept disabled)
    # print("\n========== Test Lyrics Cleaning Effect ==========")
    # music_list = load_music_data(input_file, max_count=max_count)

    # print("\n" + "="*80)
    # for i, music in enumerate(music_list, 1):
    #     print(f"\n[Song {i}]")
    #     print(f"Description: {music.get('description', '')}")
    #     print("\n--- Original Lyrics ---")
    #     original_lyrics = music.get('lyrics', '')
    #     print(original_lyrics[:500] + "..." if len(original_lyrics) > 500 else original_lyrics)
    #     print("\n--- Cleaned Lyrics ---")
    #     cleaned_lyrics = clean_lyrics(original_lyrics)
    #     print(cleaned_lyrics)
    #     print("\n" + "-"*80)
|
| 256 |
+
|
data_pipeline/lyrics_gene/gen_lyrics_cn.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import threading
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 10 |
+
|
| 11 |
+
# Set environment variables
|
| 12 |
+
# Note: Set these environment variables before running the script
|
| 13 |
+
# export OPENAI_API_KEY="your-api-key"
|
| 14 |
+
# export OPENAI_BASE_URL="https://api.openai.com/v1" # or your custom API URL
|
| 15 |
+
if not os.environ.get("OPENAI_API_KEY"):
|
| 16 |
+
os.environ["OPENAI_API_KEY"] = "" # Replace with your API key or set via environment variable
|
| 17 |
+
if not os.environ.get("OPENAI_BASE_URL"):
|
| 18 |
+
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1" # Replace with API URL or set via environment variable
|
| 19 |
+
|
| 20 |
+
# Initialize client
|
| 21 |
+
client = OpenAI()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _extract_lyrics_timestamps(lyrics_text):
|
| 25 |
+
"""
|
| 26 |
+
Extract timestamps from lyrics and convert to seconds
|
| 27 |
+
Args:
|
| 28 |
+
lyrics_text: Lyrics string
|
| 29 |
+
Returns:
|
| 30 |
+
List[float]: Timestamps in order (seconds)
|
| 31 |
+
"""
|
| 32 |
+
if not isinstance(lyrics_text, str):
|
| 33 |
+
return []
|
| 34 |
+
pattern = re.compile(r'\[(\d{2}):(\d{2})(?:\.(\d{2}))?\]')
|
| 35 |
+
timestamps = []
|
| 36 |
+
for match in pattern.finditer(lyrics_text):
|
| 37 |
+
minutes = int(match.group(1))
|
| 38 |
+
seconds = int(match.group(2))
|
| 39 |
+
fraction = match.group(3)
|
| 40 |
+
total_seconds = minutes * 60 + seconds
|
| 41 |
+
if fraction is not None:
|
| 42 |
+
divisor = 100 if len(fraction) == 2 else 10 ** len(fraction)
|
| 43 |
+
total_seconds += int(fraction) / divisor
|
| 44 |
+
timestamps.append(total_seconds)
|
| 45 |
+
return timestamps
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _validate_timestamps(lyrics_text, min_last_timestamp=170, max_interval=35):
    """
    Check whether a song's LRC timestamps satisfy length and pacing rules.

    Args:
        lyrics_text: Lyrics string carrying LRC timestamps.
        min_last_timestamp: Required minimum value of the final timestamp (s).
        max_interval: Allowed maximum gap between the last two timestamps (s).

    Returns:
        bool: True when the lyrics pass all checks.
    """
    stamps = _extract_lyrics_timestamps(lyrics_text)
    if len(stamps) < 2:
        print("Validation failed: Timestamp count less than 2")
        return False
    second_last, last = stamps[-2], stamps[-1]
    # The song must be long enough overall...
    if last < min_last_timestamp:
        print(f"Validation failed: Last timestamp {last:.2f}s is less than {min_last_timestamp}s")
        return False
    # ...and must not end with an unnaturally large gap.
    if last - second_last > max_interval:
        print(f"Validation failed: Interval between last two timestamps {last - second_last:.2f}s is greater than {max_interval}s")
        return False
    return True
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def chat_gpt(text, model='gpt-4o-mini'):
    """
    Call the OpenAI chat completions API, retrying indefinitely until a
    non-empty reply arrives.

    Args:
        text: User message content.
        model: Chat model name.

    Returns:
        The stripped reply text.
    """
    while True:
        try:
            # Call OpenAI chat completions API
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": text}
                ]
            )
            # Return the reply content when present
            if getattr(completion.choices[0].message, 'content', None):
                return completion.choices[0].message.content.strip()
            # BUGFIX: previously this branch looped back immediately with no
            # delay, busy-looping against the API; wait the 2s the message
            # promises before retrying.
            print('error_wait_2s')
            time.sleep(2)
        except Exception as e:
            print(f"Error: {e}")
            time.sleep(2)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def chat_gpt_call(text, model='gpt-4o-mini'):
    """
    Single (non-retrying) call to the OpenAI chat completions API.

    Args:
        text: User message content.
        model: Chat model name.

    Returns:
        The stripped reply text, or None (implicitly) when the response
        carries no content.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text}
        ]
    )
    message = completion.choices[0].message
    if getattr(message, 'content', None):
        return message.content.strip()
    # No content in the response; caller receives None
    print('error_wait_2s')
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def generate_music_descriptions(all_music_data, index_pool, output_file, file_lock, sample_size=20, model='gpt-4o-mini', max_retries=0):
    """
    Sample reference songs, prompt GPT for two new Chinese songs
    (tag-style description + LRC lyrics), validate their timestamps, and
    append the survivors to the output JSONL file.

    Args:
        all_music_data: Full list of reference music records.
        index_pool: Thread-safe IndexPool supplying sample indices.
        output_file: Output jsonl file path (appended under file_lock).
        file_lock: Lock serializing writes to output_file across threads.
        sample_size: Number of reference songs to sample per call.
        model: Chat model name.
        max_retries: Extra attempts when a call yields fewer than 2 songs.

    Returns:
        (used_indices, success_count): indices consumed from the pool and
        the number of songs successfully generated and saved.
    """
    # Target duration buckets: (start "m.ss", end "m.ss", minimum lyric lines)
    duration_ranges = [
        ("3.00", "3.15", 60), ("3.15", "3.30", 70), ("3.30", "3.45", 80), ("3.45", "4.00", 90),
        ("4.00", "4.15", 100), ("4.15", "4.30", 100), ("4.30", "4.45", 100)
    ]
    selected_range = random.choice(duration_ranges)
    require_length = selected_range[2]

    # Convert bucket bounds to LRC timestamps, e.g. "3.00" -> "[3:00.00]"
    start_timestamp = f"[{selected_range[0].replace('.', ':')}.00]"
    end_timestamp = f"[{selected_range[1].replace('.', ':')}.00]"

    # Human-readable Chinese duration strings, e.g. "3分00秒"
    start_duration = f"{selected_range[0].replace('.', '分')}秒"
    end_duration = f"{selected_range[1].replace('.', '分')}秒"

    # Random example timestamps inside the bucket (only printed for logging)
    start_parts = selected_range[0].split('.')
    end_parts = selected_range[1].split('.')

    start_minutes = int(start_parts[0])
    start_seconds = int(start_parts[1])
    end_minutes = int(end_parts[0])
    end_seconds = int(end_parts[1])

    start_total_seconds = start_minutes * 60 + start_seconds
    end_total_seconds = end_minutes * 60 + end_seconds

    example1_seconds = random.randint(start_total_seconds, end_total_seconds)
    example2_seconds = random.randint(start_total_seconds, end_total_seconds)

    example1_minutes = example1_seconds // 60
    example1_secs = example1_seconds % 60
    example2_minutes = example2_seconds // 60
    example2_secs = example2_seconds % 60

    example1_timestamp = f"[{example1_minutes:02d}:{example1_secs:02d}.00]"
    example2_timestamp = f"[{example2_minutes:02d}:{example2_secs:02d}.00]"

    # Get sample indices from the index pool (thread-safe)
    selected_indices = index_pool.get_indices(sample_size)
    if not selected_indices:
        return [], 0

    sample_data = [all_music_data[i] for i in selected_indices]

    # Collect unique styles from the sample
    # NOTE(review): styles_text / examples_text are built but never
    # interpolated into the prompt below -- presumably left over from an
    # earlier few-shot variant; kept for compatibility.
    styles = []
    for data in sample_data:
        style = data.get('style', '')
        if style and style not in styles:
            styles.append(style)

    styles_text = "、".join(styles)

    # Build example text from the sampled records (style field excluded)
    examples = []
    for i, data in enumerate(sample_data, 1):
        lyrics_text = " ".join(data.get('lyrics', [])) if isinstance(data.get('lyrics'), list) else data.get('lyrics', '')
        description = data.get('description', '')
        examples.append(f"示例{i}:\ndescription: {description}\nlyrics: {lyrics_text}")

    examples_text = "\n\n".join(examples)

    prompt = f"""生成2首完整的歌曲,每首歌必须满足以下硬性指标:
- 严禁生成小于{require_length}行的歌词!
- 每首歌的歌词行数必须严格大于{require_length}行,这是硬性要求!
- 最后一句时间戳必须在{start_timestamp}到{end_timestamp}之间
- 两首歌的时长、行数必须有差异,严禁最后的时间戳均相同
- 相邻歌词行的时间戳间隔不得超过10秒!必须保证时间戳连续自然递进
- 严禁出现如"[03:25.00]在心中[04:25.00]最后一行歌词"的生硬间隔,严禁超过10s的间隔
如果生成的歌曲不满足以上任意一项,则视为不合格,请重新生成。
请生成2首新的、具有多样性的音乐描述和LRC格式歌词,语言为中文。


创作要求:
1.风格与流派需要确保多样性。
2.Description标签化要求(必须严格遵守):
description字段必须使用结构化的标签格式,包括以下标签,用逗号分隔:
- 音乐风格标签
- 音乐流派标签
- 乐器标签
- 情感基调标签
- 氛围标签
- 演唱方式和人声标签,仅限男声或女生二选一,单人独唱
注意:每个标签简洁明了,多个同类标签可用斜杠分隔(如"钢琴/小提琴")
3.歌词创造力:lyrics应该具有深度和艺术性:
- 主题可以涉及爱情、人生、社会、自然、哲思、梦想、回忆等各个方面
- 运用丰富的文学手法:比喻、意象、对比、排比等
- 情感真挚,注重韵律和节奏感
- 可以是叙事性、抒情性或意识流风格
4.歌词结构和长度要求(必须严格遵守):
- lyrics必须按照以下结构组织,并使用段落标签标注每个部分
- 结构顺序必须严格遵循该顺序,共8个段落标签:[Verse 1]主歌1 → [Pre-Chorus]预副歌 → [Chorus]副歌 → [Verse 2]主歌2 → [Pre-Chorus]预副歌 → [Chorus]副歌 → [Bridge]桥段 → [Chorus (Outro)]副歌(结尾)
- 一首歌词的段落标签只有8个,即[Verse 1]和[Verse 2]只出现一次,[Pre-Chorus]和[Chorus]各出现两次,[Bridge]和[Chorus (Outro)]各出现一次,禁止额外添加或重复更多的段落标签
- 每个段落标签(如[Verse 1]、[Chorus]等)必须独占一行,后面紧跟该段落的LRC格式歌词
- 段落之间用空行分隔
- **总行数要求**:整首歌必须包含至少{require_length}行带时间戳的歌词(不包括段落标签行和空行)
5.LRC格式强制规则(必须严格遵守):
- 每行歌词格式必须为 `[mm:ss.xx]歌词内容`,时间戳与歌词间无空格,歌词内容需完整连贯
- **每一行只能包含一小句歌词**,遇到逗号、句号等标点符号时必须换行。
- **严禁将多句歌词合并在同一行**
- 时间戳需自然分布,**第一句歌词起始时间不得为 [00:00.00]**,需考虑前奏空白(建议从[00:05.00]到[00:15.00]之间开始)
- 时间戳间隔要求多样性:每首歌内部的时间戳间隔必须多样化,多采用小数点数间隔,严禁使用固定间隔:
* 同一首歌内必须包含多种不同的间隔,不要所有句子都使用相同间隔(如不要全部都是4秒间隔)
* 根据歌词内容的情感强度和音乐节拍来动态调整间隔
* 相邻歌词行的间隔应该有所变化,体现音乐的节奏起伏
- 时间戳分配应根据歌曲的风格、情感、节奏来合理推测,而非机械地按照歌词长度分配
- 每行歌词长度应自然变化,切勿长度一致
- **歌曲总时长必须达到{start_duration}到{end_duration}(即最后一句时间戳必须在{start_timestamp}到{end_timestamp}之间)这是硬性要求!**
6.歌词长度要求:lyrics字段的歌词行数必须大于{require_length}行,若生成长度过短请重新生成。
7.独特性和原创性:每首作品都应该是独一无二的,避免简单重复示例的内容。
8.格式要求:
- 直接返回JSON数组格式,包含2个歌曲对象,每个对象只有description和lyrics两个字段
- description字段:必须是标签格式,不是叙述性文本
- lyrics字段:带段落标签的LRC格式字符串
- 严禁在JSON中插入任何额外的符号、标记、注释或说明文字

LRC格式示例(带段落标签):
[Verse 1]
[00:08.00]第一句歌词
[00:12.50]第二句歌词
[00:17.20]第三句歌词

[Pre-Chorus]
[00:22.00]预副歌歌词
[00:26.50]预副歌歌词

[Chorus]
[00:31.00]副歌歌词
[00:35.50]副歌歌词

负面示例(禁止出现):
- 错误:[01:30.00](钢琴间奏) - 禁止在时间戳后加括号注释
- 错误:[00:00.00]开始的歌词 - 第一句不能从00:00.00开始
- 错误: [00:05.00]在那片熟悉的田野,阳光洒满金色的麦穗 - 严禁多句歌词放在同一行

现在,请充分发挥你的创造力,生成2首全新的、完整的音乐描述和LRC格式歌词作品。
特别提醒:每首歌必须是完整歌曲,不要缩写或省略!必须包含完整的8个段落(Verse 1, Pre-Chorus, Chorus, Verse 2, Pre-Chorus, Chorus, Bridge, Chorus Outro),严格确保大于{require_length}行歌词。

直接返回JSON数组格式:
[
{{"description": "...", "lyrics": "..."}},
{{"description": "...", "lyrics": "..."}}
]"""
    # Try to generate, with an optional retry loop
    for attempt in range(max_retries + 1):
        try:
            # Call OpenAI API
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": f"You are a creative music lyricist and composer. Please generate diverse and creative music tag-based descriptions and LRC format lyrics with song structure tags. CRITICAL REQUIREMENTS: 1) Description must be structured tags separated by commas, NOT narrative text. 2) Return ONLY pure, valid JSON format without any extra symbols, markers, or comments. 3) Each song must include structure tags like [Verse 1], [Chorus], [Bridge], etc., followed by LRC format lyrics [mm:ss.xx]lyric_content. 4) MANDATORY: Each song must have MORE than {require_length} lines of lyrics with timestamps. "},
                    {"role": "user", "content": prompt}
                ],
                n=1,
                temperature=1.0,
            )
            # Extract and parse all responses
            results = []
            filtered_count = 0
            last_content = None

            for i, choice in enumerate(completion.choices, 1):
                try:
                    content = choice.message.content.strip()
                    last_content = content
                    print(f"\n=== GPT Response {i} ===")
                    print(content)
                    print("=" * 50)
                    # Strip markdown code fences when present
                    if "```json" in content:
                        content = content.split("```json")[1].split("```")[0].strip()
                    elif "```" in content:
                        content = content.split("```")[1].split("```")[0].strip()

                    # Tolerate trailing commas before } or ] in the JSON
                    content = re.sub(r',(\s*[}\]])', r'\1', content)

                    # Parse JSON array
                    result_array = json.loads(content)

                    if isinstance(result_array, list):
                        # Validate each song object in the array
                        for song in result_array:
                            if isinstance(song, dict) and 'description' in song and 'lyrics' in song:
                                if _validate_timestamps(song.get('lyrics', '')):
                                    results.append(song)
                                else:
                                    filtered_count += 1
                    # Single-object response (compatibility with old format)
                    elif isinstance(result_array, dict) and 'description' in result_array and 'lyrics' in result_array:
                        if _validate_timestamps(result_array.get('lyrics', '')):
                            results.append(result_array)
                        else:
                            filtered_count += 1

                except json.JSONDecodeError:
                    continue

            if filtered_count:
                print(f"Total {filtered_count} songs filtered due to timestamp validation failure")

            # Print parsing results
            print(f"\nParsing complete, results length: {len(results)}")
            print(f"Results content: {results}")
            print(start_duration, end_duration, example1_timestamp, example2_timestamp, require_length)

            # Dump the raw model response for debugging when we did not get
            # exactly the 2 songs the prompt requested
            if len(results) != 2:
                print(f"Warning: Parsed result length is not 2, actual is {len(results)}, will write to test.txt")
                with open('test.txt', 'w', encoding='utf-8') as f:
                    if last_content is not None:
                        f.write(last_content)
                print("Written to test.txt file")

            # Success path: the prompt asks for exactly 2 songs per call.
            # BUGFIX: the old check compared against 50 (left over from a
            # "10 responses * 5 songs" setup) and could never succeed here,
            # so the retry mechanism was effectively dead.
            if len(results) >= 2:
                # Append results to file (lock ensures thread safety)
                with file_lock:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        for result in results[:2]:  # Only save the first 2 songs
                            f.write(json.dumps(result, ensure_ascii=False) + '\n')

                return selected_indices, min(len(results), 2)
            elif attempt < max_retries:
                print(f"Only successfully parsed {len(results)}/2 songs, retrying...")
                time.sleep(2)
            else:
                # Last attempt: save whatever validated, even if fewer than 2
                if len(results) > 0:
                    with file_lock:
                        with open(output_file, 'a', encoding='utf-8') as f:
                            for result in results:
                                f.write(json.dumps(result, ensure_ascii=False) + '\n')
                return selected_indices, len(results)

        except Exception as e:
            if attempt < max_retries:
                print(f"Error occurred during generation: {e}, retrying...")
                time.sleep(2)
            else:
                print(f"Generation failed: {e}")
                return selected_indices, 0

    return selected_indices, 0
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class IndexPool:
    """Thread-safe pool of dataset indices with automatic reset support.

    Hands out unused indices, persists every selection to a file so runs
    can resume, and transparently starts a new pass once all indices have
    been consumed.
    """

    def __init__(self, total_size, selected_file):
        """
        Args:
            total_size: Dataset size; valid indices are 0..total_size-1.
            selected_file: Path used to persist already-selected indices.
        """
        self.total_size = total_size
        self.selected_file = selected_file
        self.lock = threading.Lock()
        self.available_indices = []
        self.selected_indices = set()
        self.reset_count = 0  # How many times the pool has been reset

        # Restore prior selections, then build the available list
        self._load_selected_indices()
        self._reset_pool()

    def _load_selected_indices(self):
        """Restore previously selected indices from the persistence file."""
        if os.path.exists(self.selected_file):
            with open(self.selected_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # BUGFIX: skip blank lines -- int('') raises ValueError,
                    # which would crash resumption after a partial write.
                    if line:
                        self.selected_indices.add(int(line))

    def _reset_pool(self):
        """Rebuild the shuffled list of not-yet-selected indices."""
        self.available_indices = [i for i in range(self.total_size) if i not in self.selected_indices]
        random.shuffle(self.available_indices)  # Shuffle order

        if len(self.available_indices) == 0:
            # All indices consumed: start a fresh pass over the whole dataset
            self.reset_count += 1
            print(f"\nIndex pool exhausted, resetting pool for the {self.reset_count}th time, re-selecting from {self.total_size} songs")
            self.selected_indices.clear()
            self.available_indices = list(range(self.total_size))
            random.shuffle(self.available_indices)

    def get_indices(self, count):
        """
        Thread-safely take up to `count` indices from the pool.

        Args:
            count: Number of indices requested. If fewer remain even after
                a reset, the returned list may be shorter.

        Returns:
            List of selected indices.
        """
        with self.lock:
            # Refill (and possibly reset) when the pool is running low
            if len(self.available_indices) < count:
                self._reset_pool()

            selected = self.available_indices[:count]
            self.available_indices = self.available_indices[count:]

            # Record the selections in memory...
            for idx in selected:
                self.selected_indices.add(idx)

            # ...and persist them so a later run can resume
            with open(self.selected_file, 'a', encoding='utf-8') as f:
                for idx in selected:
                    f.write(f"{idx}\n")

            return selected

    def get_stats(self):
        """Return counts of available/selected indices and the reset count."""
        with self.lock:
            return {
                'available': len(self.available_indices),
                'selected': len(self.selected_indices),
                'reset_count': self.reset_count
            }
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def batch_generate_music(input_file, output_file, selected_file, total_songs=1000, sample_size=20, model='gpt-4o-mini', num_threads=10):
    """
    Batch generate music descriptions and lyrics (multi-threaded version).

    Args:
        input_file: Path to input jsonl file (one song JSON object per line)
        output_file: Path to output jsonl file (appended to by worker threads)
        selected_file: Path to file recording selected indices
        total_songs: Total number of songs to generate
        sample_size: Number of example songs sampled per API call
        model: Model name to use
        num_threads: Number of worker threads
    """
    # Load all music data
    print("Loading music data...")
    all_music_data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line.strip())
            all_music_data.append(data)
    print(f"Loaded {len(all_music_data)} songs")

    # Create thread-safe index pool (restores progress from selected_file)
    index_pool = IndexPool(len(all_music_data), selected_file)
    stats = index_pool.get_stats()
    print(f"Currently selected indices: {stats['selected']}")
    print(f"Currently available indices: {stats['available']}")

    # NOTE(review): dividing by 2 matches the prompt, which requests 2 songs
    # per call, but the comments and progress messages still say "5 per call"
    # — presumably leftovers from an earlier prompt version. Confirm intent.
    num_iterations = (total_songs + 1) // 2  # Round up
    print(f"Need to call {num_iterations} times to generate approximately {total_songs} songs (5 per call)")
    print(f"Using {num_threads} threads for parallel processing\n")

    # Create file write lock (serializes appends to output_file across workers)
    file_lock = threading.Lock()

    # Statistics shared across worker threads
    total_generated = 0
    generated_lock = threading.Lock()

    def worker_task(task_id):
        """Worker thread task: one generation call; returns songs produced."""
        try:
            used_indices, success_count = generate_music_descriptions(
                all_music_data=all_music_data,
                index_pool=index_pool,
                output_file=output_file,
                file_lock=file_lock,
                sample_size=sample_size,
                model=model,
                max_retries=0  # no in-task retries; failures just yield 0 songs
            )
            return success_count
        except Exception as e:
            # Swallow per-task failures so one bad call cannot kill the batch.
            print(f"Task {task_id} failed: {e}")
            return 0

    # Use thread pool and progress bar
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks up front
        futures = {executor.submit(worker_task, i): i for i in range(num_iterations)}

        # Use tqdm to show progress; tasks are consumed in completion order
        with tqdm(total=num_iterations, desc="Generation progress", unit="batch") as pbar:
            for future in as_completed(futures):
                success_count = future.result()

                with generated_lock:
                    total_generated += success_count

                # Get current statistics
                stats = index_pool.get_stats()

                # Update progress bar
                pbar.set_postfix({
                    'Batch': f'{success_count}/5',
                    'Total': total_generated,
                    'Remaining': stats['available'],
                    'Resets': stats['reset_count']
                })
                pbar.update(1)

    # Final statistics
    stats = index_pool.get_stats()
    print(f"\nGeneration complete!")
    print(f"Total generated: {total_generated} songs")
    print(f"Used {stats['selected']} indices")
    print(f"Remaining available indices: {stats['available']}")
    print(f"Pool reset count: {stats['reset_count']}")
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
if __name__ == '__main__':
    # Input corpus, generated-output path, and the ledger of used sample indices.
    input_file = 'tagged_musics.jsonl'
    output_file = 'generate_lrc_5mini.jsonl'
    selected_file = 'selected.txt'
    # Small smoke-test run: 10 songs total, sampling 4 example songs per call.
    run_config = dict(
        input_file=input_file,
        output_file=output_file,
        selected_file=selected_file,
        total_songs=10,
        sample_size=4,
        model='gpt-5-mini',
        num_threads=5,
    )
    batch_generate_music(**run_config)
|
data_pipeline/lyrics_gene/gen_lyrics_en.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import threading
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 10 |
+
|
| 11 |
+
# Set environment variables
# Note: Set these environment variables before running the script
# export OPENAI_API_KEY="your-api-key"
# export OPENAI_BASE_URL="https://api.openai.com/v1" # or your custom API URL
# These fallbacks run at import time, BEFORE the OpenAI client below is
# constructed, so the client picks them up from the process environment.
if not os.environ.get("OPENAI_API_KEY"):
    # SECURITY NOTE: do not commit a real key in this literal; prefer the
    # environment variable.
    os.environ["OPENAI_API_KEY"] = ""  # Replace with your API key or set via environment variable
if not os.environ.get("OPENAI_BASE_URL"):
    os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1"  # Replace with API URL or set via environment variable

# Initialize client
# Module-level singleton shared by every helper in this file.
# NOTE(review): it is also shared across the worker threads spawned by
# batch_generate_music — confirm this SDK version's client is thread-safe.
client = OpenAI()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _extract_lyrics_timestamps(lyrics_text):
|
| 25 |
+
"""
|
| 26 |
+
Extract timestamps from lyrics and convert to seconds
|
| 27 |
+
Args:
|
| 28 |
+
lyrics_text: Lyrics string
|
| 29 |
+
Returns:
|
| 30 |
+
List[float]: Timestamps in order (seconds)
|
| 31 |
+
"""
|
| 32 |
+
if not isinstance(lyrics_text, str):
|
| 33 |
+
return []
|
| 34 |
+
pattern = re.compile(r'\[(\d{2}):(\d{2})(?:\.(\d{2}))?\]')
|
| 35 |
+
timestamps = []
|
| 36 |
+
for match in pattern.finditer(lyrics_text):
|
| 37 |
+
minutes = int(match.group(1))
|
| 38 |
+
seconds = int(match.group(2))
|
| 39 |
+
fraction = match.group(3)
|
| 40 |
+
total_seconds = minutes * 60 + seconds
|
| 41 |
+
if fraction is not None:
|
| 42 |
+
divisor = 100 if len(fraction) == 2 else 10 ** len(fraction)
|
| 43 |
+
total_seconds += int(fraction) / divisor
|
| 44 |
+
timestamps.append(total_seconds)
|
| 45 |
+
return timestamps
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _validate_timestamps(lyrics_text, min_last_timestamp=170, max_interval=35):
    """
    Check that the lyrics' LRC timestamps satisfy the duration constraints.

    Args:
        lyrics_text: Lyrics string in LRC format
        min_last_timestamp: Smallest acceptable final timestamp (seconds)
        max_interval: Largest acceptable gap between the last two lines (seconds)
    Returns:
        bool: True when every check passes; otherwise the first failing
        reason is printed and False is returned.
    """
    stamps = _extract_lyrics_timestamps(lyrics_text)
    if len(stamps) < 2:
        print("Validation failed: Timestamp count less than 2")
        return False

    # Only the last two timestamps matter for the remaining checks.
    *_, second_last, last = stamps

    if last < min_last_timestamp:
        print(f"Validation failed: Last timestamp {last:.2f}s is less than {min_last_timestamp}s")
        return False
    if last - second_last > max_interval:
        print(f"Validation failed: Interval between last two timestamps {last - second_last:.2f}s is greater than {max_interval}s")
        return False
    return True
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def chat_gpt(text, model='gpt-4o-mini'):
    """
    Send *text* to the chat-completions API and return the reply text.

    Retries forever: on any API error, or on an empty/missing response
    content, waits 2 seconds and tries again.

    Args:
        text: User message to send
        model: Chat model name

    Returns:
        str: The stripped assistant reply
    """
    while True:
        try:
            # Call OpenAI chat completions API
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": text}
                ]
            )
            # Return as soon as the model produced non-empty content
            if getattr(completion.choices[0].message, 'content', None):
                content = completion.choices[0].message.content.strip()
                return content
            else:
                print('error_wait_2s')
                # Fix: this branch previously looped immediately with no
                # delay (a busy retry loop hammering the API) even though the
                # printed marker advertises a 2-second wait.
                time.sleep(2)
        except Exception as e:
            print(f"Error: {e}")
            time.sleep(2)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def chat_gpt_call(text, model='gpt-4o-mini'):
    """
    Single (non-retrying) chat-completions call.

    Returns the stripped assistant reply, or None (after printing an error
    marker) when the response carries no content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text}
        ]
    )
    reply = getattr(response.choices[0].message, 'content', None)
    if reply:
        return reply.strip()
    print('error_wait_2s')
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def generate_music_descriptions(all_music_data, index_pool, output_file, file_lock, sample_size=20, model='gpt-4o-mini', max_retries=0):
    """
    Sample songs from the corpus and ask the model for 2 new description +
    LRC-lyrics pairs, validating and appending the results to output_file.

    Args:
        all_music_data: List of all music data
        index_pool: Index pool object (thread-safe)
        output_file: Path to output jsonl file
        file_lock: File write lock
        sample_size: Number of samples to randomly extract
        model: Model name to use
        max_retries: Maximum retry count

    Returns:
        (used_indices, success_count): List of used indices and number of successful generations
    """
    # duration_ranges = [
    #     ("3.00", "3.15", 50), ("3.15", "3.30", 50), ("3.30", "3.45", 60),
    #     ("3.45", "4.00", 60),
    #     ("4.00", "4.15", 70), ("4.15", "4.30", 70), ("4.30", "4.45", 70)
    # ]
    # Each tuple: (start "m.ss", end "m.ss", required minimum lyric line count).
    duration_ranges = [
        ("3.00", "3.15", 60), ("3.15", "3.30", 70), ("3.30", "3.45", 80),("3.45", "4.00", 90),
        ("4.00", "4.15", 100), ("4.15", "4.30", 100), ("4.30", "4.45", 100)
    ]
    selected_range = random.choice(duration_ranges)
    require_length = selected_range[2]

    # Directly convert to timestamp format (strictly corresponding to left and right ends of tuple)
    start_timestamp = f"[{selected_range[0].replace('.', ':')}.00]"
    end_timestamp = f"[{selected_range[1].replace('.', ':')}.00]"

    # Generate duration description
    # NOTE(review): "3.00" here means 3 min 00 sec, but float("3.00") == 3.0
    # is then treated as SECONDS, so start_duration becomes "0min 3sec"
    # instead of "3min 0sec". The prompt therefore receives a wrong duration
    # phrase (the timestamps above are computed correctly). Verify intent.
    start_seconds = float(selected_range[0])
    start_minutes = int(start_seconds // 60)
    start_secs = int(start_seconds % 60)
    start_duration = f"{start_minutes}min {start_secs}sec"

    end_seconds = float(selected_range[1])
    end_minutes = int(end_seconds // 60)
    end_secs = int(end_seconds % 60)
    end_duration = f"{end_minutes}min {end_secs}sec"

    # Generate random timestamps in examples (randomly generated within time range)
    # Parse time string to minutes and seconds
    start_parts = selected_range[0].split('.')
    end_parts = selected_range[1].split('.')

    start_minutes = int(start_parts[0])
    start_seconds = int(start_parts[1])
    end_minutes = int(end_parts[0])
    end_seconds = int(end_parts[1])

    # Convert to total seconds
    start_total_seconds = start_minutes * 60 + start_seconds
    end_total_seconds = end_minutes * 60 + end_seconds

    # Randomly generate within range
    example1_seconds = random.randint(start_total_seconds, end_total_seconds)
    example2_seconds = random.randint(start_total_seconds, end_total_seconds)

    example1_minutes = example1_seconds // 60
    example1_secs = example1_seconds % 60
    example2_minutes = example2_seconds // 60
    example2_secs = example2_seconds % 60

    # NOTE(review): example1_timestamp / example2_timestamp are only printed
    # below; they are not interpolated into the prompt.
    example1_timestamp = f"[{example1_minutes:02d}:{example1_secs:02d}.00]"
    example2_timestamp = f"[{example2_minutes:02d}:{example2_secs:02d}.00]"

    # Get sample indices from index pool (thread-safe)
    selected_indices = index_pool.get_indices(sample_size)
    if not selected_indices:
        return [], 0

    sample_data = [all_music_data[i] for i in selected_indices]

    # Extract all unique styles
    styles = []
    for data in sample_data:
        style = data.get('style', '')
        if style and style not in styles:
            styles.append(style)

    # NOTE(review): styles_text and examples_text below are built but never
    # referenced by the prompt f-string — presumably left over from an
    # earlier prompt that embedded the sampled examples. Verify.
    styles_text = "、".join(styles)

    # Build example text - include all sampled data (excluding style)
    examples = []
    for i, data in enumerate(sample_data, 1):
        lyrics_text = " ".join(data.get('lyrics', [])) if isinstance(data.get('lyrics'), list) else data.get('lyrics', '')
        description = data.get('description', '')
        examples.append(f"Example {i}:\ndescription: {description}\nlyrics: {lyrics_text}")

    examples_text = "\n\n".join(examples)

    # The full instruction prompt sent as the user message. Interpolates the
    # required line count and the timestamp/duration window chosen above.
    prompt = f"""Generate 2 complete songs. Each song must meet the following hard requirements:
- Strictly forbidden to generate lyrics with fewer than {require_length} lines!
- The number of lyric lines for each song must be strictly greater than {require_length}. This is a hard requirement!
- The timestamp of the final line must be between {start_timestamp} and {end_timestamp}.
- The two songs must differ in duration and line count; their final timestamps must not be identical.
- The timestamp interval between adjacent lyric lines must not exceed 10 seconds! Timestamps must be continuous and progress naturally.
- Awkward gaps like "[03:25.00]in the heart[04:25.00]the last lyric" are strictly forbidden. Do not exceed a 10-second interval.
- It is strictly forbidden to repeat the entire structure or its sections after one iteration is complete. It is also strictly forbidden to repeat the same lyric line multiple times.
If any of the above requirements are not met, the generation is considered a failure. Please regenerate.
Please generate 2 new, diverse music descriptions and LRC format lyrics. The language should be English.


Creative Requirements:
1. Style and Genre must be diverse.
2. Description Tagging Requirements (Must be strictly followed):
   The description field must use a structured tag format, including the following tags, separated by commas:
   - Music Style tag
   - Music Genre tag
   - Instruments tag
   - Emotional Tone tag
   - Mood/Atmosphere tag
   - Vocal Style and Voice tag, limited to either "male voice" or "female voice", solo performance only.
   Note: Each tag should be concise. Multiple tags of the same category can be separated by a slash (e.g., "Piano/Violin").
3. Lyric Creativity: The lyrics should have depth and artistry:
   - Themes can cover various aspects such as love, life, society, nature, philosophy, dreams, memories, etc.
   - Use rich literary devices: metaphors, imagery, contrast, parallelism, etc.
   - Express sincere emotions with a focus on rhyme and rhythm.
   - The style can be narrative, lyrical, or stream-of-consciousness.
4. Lyric Structure and Length Requirements (Must be strictly followed):
   - The lyrics must be organized using the following structure, with section tags annotating each part.
   - The structure must strictly follow this order, for a total of 8 section tags: [Verse 1] → [Pre-Chorus] → [Chorus] → [Verse 2] → [Pre-Chorus] → [Chorus] → [Bridge] → [Chorus (Outro)].
   - A single song can only have these 8 section tags. [Verse 1] and [Verse 2] appear once; [Pre-Chorus] and [Chorus] appear twice; [Bridge] and [Chorus (Outro)] appear once. Do not add or repeat extra section tags.
   - Each section tag (e.g., [Verse 1], [Chorus]) must be on its own line, immediately followed by the LRC format lyrics for that section.
   - Separate sections with a blank line.
   - **Total Line Count Requirement**: The entire song must contain at least {require_length} lines of timestamped lyrics (not including section tags or blank lines).
5. LRC Format Mandatory Rules (Must be strictly followed):
   - Each line of lyrics must be in the format `[mm:ss.xx]Lyric content`, with no space between the timestamp and the lyrics. The lyric content should be coherent.
   - **Each line must contain only one short phrase of lyrics.** Start a new line when encountering punctuation like commas or periods.
   - **Strictly forbidden to merge multiple sentences or clauses onto the same line.**
   - Timestamps must be distributed naturally. **The first line's timestamp must not be [00:00.00]**. Allow for an instrumental intro (suggestion: start between [00:05.00] and [00:15.00]).
   - Timestamp intervals must be varied: The intervals within each song must be diverse, often using decimal values. Do not use a fixed interval:
     * A single song must contain a variety of different intervals; do not use the same interval for all lines (e.g., not all 4-second gaps).
     * Dynamically adjust intervals based on the emotional intensity and rhythm of the lyrics.
     * The gap between adjacent lines should vary to reflect the musical rhythm.
   - Timestamp allocation should be reasonably inferred based on the song's style, emotion, and rhythm, not mechanically assigned based on lyric length.
   - The length of each lyric line should vary naturally; do not make them all uniform.
   - **The total song duration must be between {start_duration} and {end_duration} (meaning the final line's timestamp must be between {start_timestamp} and {end_timestamp}). This is a hard requirement!**
6. Lyric Length Requirement: The number of lyric lines in the lyrics field must be greater than {require_length}. If the generated length is too short, please regenerate.
7. Uniqueness and Originality: Each piece should be unique. Avoid simply repeating the content from examples.
8. Format Requirements:
   - Directly return a JSON array containing 2 song objects. Each object must have only "description" and "lyrics" fields.
   - `description` field: Must be in tag format, not narrative text.
   - `lyrics` field: A string in LRC format with section tags.
   - Strictly forbidden to insert any extra symbols, markers, comments, or explanatory text within the JSON.

LRC Format Example (with section tags):
[Verse 1]
[00:08.00]First line of lyrics
[00:12.50]Second line of lyrics
[00:17.20]Third line of lyrics

[Pre-Chorus]
[00:22.00]Pre-chorus lyrics
[00:26.50]Pre-chorus lyrics

[Chorus]
[00:31.00]Chorus lyrics
[00:35.50]Chorus lyrics

Negative Examples (to avoid):
- Incorrect: [01:30.00](Piano Interlude) - Do not add parenthetical comments after the timestamp.
- Incorrect: [00:00.00]Starting lyric - The first line cannot start at 00:00.00.
- Incorrect: [00:05.00]In the familiar field, the sun casts golden rays upon the wheat - Strictly forbidden to place multiple clauses on the same line.
- Incorrect: [03:00.00] In the light of hope[03:05.50] In the light of hope[03:10.20] In the light of hope -Excessive repetition of the exact same lyric line is strictly forbidden. Lyrical content must show variation.
Now, please fully unleash your creativity and generate 2 new, complete works of music descriptions and LRC format lyrics.
Special Reminder: Each song must be complete, not abbreviated or omitted! It must contain the full 8 sections (Verse 1, Pre-Chorus, Chorus, Verse 2, Pre-Chorus, Chorus, Bridge, Chorus Outro) and strictly ensure more than {require_length} lines of lyrics.

Directly return in JSON array format:
[
{{"description": "...", "lyrics": "..."}},
{{"description": "...", "lyrics": "..."}}
]"""
    # Try to generate with retry mechanism
    for attempt in range(max_retries + 1):
        try:
            # Call OpenAI API
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": f"You are a creative music lyricist and composer. Please generate diverse and creative music tag-based descriptions and LRC format lyrics with song structure tags. CRITICAL REQUIREMENTS: 1) Description must be structured tags separated by commas, NOT narrative text. 2) Return ONLY pure, valid JSON format without any extra symbols, markers, or comments. 3) Each song must include structure tags like [Verse 1], [Chorus], [Bridge], etc., followed by LRC format lyrics [mm:ss.xx]lyric_content. 4) MANDATORY: Each song must have MORE than {require_length} lines of lyrics with timestamps. "},
                    {"role": "user", "content": prompt}
                ],
                n=1,
                temperature=1.0,
            )
            #print(prompt)
            # Extract all responses (n=1, so normally a single choice)
            results = []
            filtered_count = 0
            last_content = None

            for i, choice in enumerate(completion.choices, 1):
                try:
                    content = choice.message.content.strip()
                    last_content = content
                    print(f"\n=== GPT Response {i} ===")
                    print(content)
                    print("=" * 50)
                    # Try to extract JSON content from a fenced code block
                    if "```json" in content:
                        content = content.split("```json")[1].split("```")[0].strip()
                    elif "```" in content:
                        content = content.split("```")[1].split("```")[0].strip()

                    # Clean trailing commas in JSON (extra commas)
                    # Remove commas after last element of object/array
                    content = re.sub(r',(\s*[}\]])', r'\1', content)

                    # Parse JSON array
                    result_array = json.loads(content)

                    # Ensure it's a list
                    if isinstance(result_array, list):
                        # Validate each object in array
                        for song in result_array:
                            if isinstance(song, dict) and 'description' in song and 'lyrics' in song:
                                if _validate_timestamps(song.get('lyrics', '')):
                                    results.append(song)
                                else:
                                    filtered_count += 1
                    # If returned a single object (compatibility with old format)
                    elif isinstance(result_array, dict) and 'description' in result_array and 'lyrics' in result_array:
                        if _validate_timestamps(result_array.get('lyrics', '')):
                            results.append(result_array)
                        else:
                            filtered_count += 1

                except json.JSONDecodeError:
                    # Unparseable response: skip this choice, keep the rest
                    continue

            if filtered_count:
                print(f"Total {filtered_count} songs filtered due to timestamp validation failure")

            # Print parsing results
            print(f"\nParsing complete, results length: {len(results)}")
            print(f"Results content: {results}")
            print(start_duration, end_duration,example1_timestamp,example2_timestamp,require_length)

            # If parsed result length is not 2, write model response content to test.txt
            # (debug aid; note 'w' mode means concurrent workers overwrite each other)
            if len(results) != 2:
                print(f"Warning: Parsed result length is not 2, actual is {len(results)}, will write to test.txt")
                with open('test.txt', 'w', encoding='utf-8') as f:
                    if last_content is not None:
                        f.write(last_content)
                print("Written to test.txt file")

            # NOTE(review): this threshold (and the "10 responses * 5 each"
            # comment) dates from a 50-songs-per-call setup; the prompt now
            # requests 2 songs, so this branch is effectively unreachable and
            # results are saved by the `else` below instead. Verify intent.
            # Check if successfully generated 50 songs (10 responses * 5 each)
            if len(results) >= 50:
                # Append save results to file (use lock to ensure thread safety)
                with file_lock:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        for result in results[:50]:  # Only save first 50 songs
                            f.write(json.dumps(result, ensure_ascii=False) + '\n')

                return selected_indices, min(len(results), 50)
            elif attempt < max_retries:
                print(f"Only successfully parsed {len(results)}/50 songs, retrying...")
                time.sleep(2)
            else:
                # Last attempt, save even if not 50 songs
                if len(results) > 0:
                    with file_lock:
                        with open(output_file, 'a', encoding='utf-8') as f:
                            for result in results:
                                f.write(json.dumps(result, ensure_ascii=False) + '\n')
                return selected_indices, len(results)

        except Exception as e:
            if attempt < max_retries:
                print(f"Error occurred during generation: {e}, retrying...")
                time.sleep(2)
            else:
                print(f"Generation failed: {e}")
                return selected_indices, 0

    return selected_indices, 0
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class IndexPool:
    """Thread-safe pool of dataset indices with automatic wrap-around.

    Dispenses indices that have not been used yet, records every dispensed
    index in ``selected_file`` (one per line) so progress persists across
    runs, and starts a fresh pass over the full range once everything has
    been consumed.
    """

    def __init__(self, total_size, selected_file):
        self.total_size = total_size
        self.selected_file = selected_file
        self.lock = threading.Lock()
        self.available_indices = []
        self.selected_indices = set()
        self.reset_count = 0  # how many times the pool wrapped around

        self._load_selected_indices()  # restore persisted progress
        self._reset_pool()             # build the initial shuffled pool

    def _load_selected_indices(self):
        """Restore the set of already-used indices from disk, if present."""
        if os.path.exists(self.selected_file):
            with open(self.selected_file, 'r', encoding='utf-8') as ledger:
                self.selected_indices.update(int(row.strip()) for row in ledger)

    def _reset_pool(self):
        """Recompute the shuffled list of unused indices; wrap around if empty."""
        remaining = [idx for idx in range(self.total_size) if idx not in self.selected_indices]
        random.shuffle(remaining)
        self.available_indices = remaining

        if not self.available_indices:
            # Everything has been used at least once: begin a new pass.
            self.reset_count += 1
            print(f"\nIndex pool exhausted, resetting pool for the {self.reset_count}th time, re-selecting from {self.total_size} songs")
            self.selected_indices.clear()
            fresh = list(range(self.total_size))
            random.shuffle(fresh)
            self.available_indices = fresh

    def get_indices(self, count):
        """
        Thread-safe: pop up to ``count`` unused indices, mark them used,
        append them to the ledger file, and return them as a list.

        Args:
            count: Number of indices needed

        Returns:
            List of selected indices
        """
        with self.lock:
            if len(self.available_indices) < count:
                self._reset_pool()

            # Split the shuffled pool into the slice we hand out and the rest.
            taken, self.available_indices = (
                self.available_indices[:count],
                self.available_indices[count:],
            )

            self.selected_indices.update(taken)

            # Persist immediately so a restart never re-issues these indices.
            with open(self.selected_file, 'a', encoding='utf-8') as ledger:
                ledger.writelines(f"{idx}\n" for idx in taken)

            return taken

    def get_stats(self):
        """Snapshot of pool state: available/selected counts and reset count."""
        with self.lock:
            return {
                'available': len(self.available_indices),
                'selected': len(self.selected_indices),
                'reset_count': self.reset_count,
            }
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def batch_generate_music(input_file, output_file, selected_file, total_songs=1000, sample_size=20, model='gpt-4o-mini', num_threads=10):
    """
    Batch generate music descriptions and lyrics (multi-threaded version)

    Loads all songs from `input_file`, then fans out `num_iterations` calls to
    generate_music_descriptions across a thread pool; each call samples
    `sample_size` reference songs from a shared IndexPool and appends its
    results to `output_file`.

    Args:
        input_file: Path to input jsonl file
        output_file: Path to output jsonl file
        selected_file: Path to file recording selected indices
        total_songs: Total number of songs to generate
        sample_size: Number of samples to extract each time
        model: Model name to use
        num_threads: Number of threads
    """
    # Load all music data (one JSON object per line)
    print("Loading music data...")
    all_music_data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line.strip())
            all_music_data.append(data)
    print(f"Loaded {len(all_music_data)} songs")

    # Create thread-safe index pool (resumes from selected_file if it exists)
    index_pool = IndexPool(len(all_music_data), selected_file)
    stats = index_pool.get_stats()
    print(f"Currently selected indices: {stats['selected']}")
    print(f"Currently available indices: {stats['available']}")

    # Calculate number of calls needed (5 songs per call)
    # NOTE(review): the formula below implies 2 songs per call
    # ((total_songs + 1) // 2), but the surrounding comments and the
    # 'Batch: x/5' progress label assume 5 per call — confirm which is intended.
    num_iterations = (total_songs + 1) // 2  # Round up
    print(f"Need to call {num_iterations} times to generate approximately {total_songs} songs (5 per call)")
    print(f"Using {num_threads} threads for parallel processing\n")

    # Create file write lock (serializes appends to output_file across workers)
    file_lock = threading.Lock()

    # Statistics (total_generated is shared, so it gets its own lock)
    total_generated = 0
    generated_lock = threading.Lock()

    def worker_task(task_id):
        """Worker thread task: one generation batch; returns the success count (0 on failure)."""
        try:
            used_indices, success_count = generate_music_descriptions(
                all_music_data=all_music_data,
                index_pool=index_pool,
                output_file=output_file,
                file_lock=file_lock,
                sample_size=sample_size,
                model=model,
                max_retries=0  # Retry
            )
            return success_count
        except Exception as e:
            # Swallow per-task failures so one bad batch cannot kill the run
            print(f"Task {task_id} failed: {e}")
            return 0

    # Use thread pool and progress bar
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = {executor.submit(worker_task, i): i for i in range(num_iterations)}

        # Use tqdm to show progress (updated as futures complete, not in submit order)
        with tqdm(total=num_iterations, desc="Generation progress", unit="batch") as pbar:
            for future in as_completed(futures):
                success_count = future.result()

                with generated_lock:
                    total_generated += success_count

                # Get current statistics
                stats = index_pool.get_stats()

                # Update progress bar
                pbar.set_postfix({
                    'Batch': f'{success_count}/5',
                    'Total': total_generated,
                    'Remaining': stats['available'],
                    'Resets': stats['reset_count']
                })
                pbar.update(1)

    # Final statistics
    stats = index_pool.get_stats()
    print(f"\nGeneration complete!")
    print(f"Total generated: {total_generated} songs")
    print(f"Used {stats['selected']} indices")
    print(f"Remaining available indices: {stats['available']}")
    print(f"Pool reset count: {stats['reset_count']}")
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
if __name__ == '__main__':
    # n=1, max_retries=0, sample 10 songs each time, generate 5 new songs
    batch_generate_music(
        input_file='tagged_musics.jsonl',
        output_file='generate_en_lrc.jsonl',
        selected_file='selected.txt',
        total_songs=100,
        sample_size=4,
        model='gpt-4o-mini',
        num_threads=20  # Test with 1 thread first
    )
    # Append to txt file
|
data_pipeline/meta_process/convert_convs.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate multi-turn dialogue data, each turn contains lyric text and corresponding audio token slices
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from my_tool import load_jsonl
|
| 11 |
+
|
| 12 |
+
TOKEN_PER_SECOND = 25 # Number of tokens per second of audio
|
| 13 |
+
NUM_ITEMS = 100000 # Process N items first
|
| 14 |
+
|
| 15 |
+
timestamp_pattern = re.compile(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]")
|
| 16 |
+
|
| 17 |
+
def _parse_lyric_with_timestamps(lyric: str):
|
| 18 |
+
"""
|
| 19 |
+
Return [(start_time_s, text), ...] sorted by timestamp
|
| 20 |
+
"""
|
| 21 |
+
result = []
|
| 22 |
+
for match in timestamp_pattern.finditer(lyric):
|
| 23 |
+
start_idx = match.end()
|
| 24 |
+
end_idx = lyric.find("[", start_idx)
|
| 25 |
+
text = lyric[start_idx:end_idx].strip() if end_idx != -1 else lyric[start_idx:].strip()
|
| 26 |
+
if not text:
|
| 27 |
+
continue
|
| 28 |
+
minute = int(match.group(1))
|
| 29 |
+
second = int(match.group(2))
|
| 30 |
+
ms = int(match.group(3)) if match.group(3) else 0
|
| 31 |
+
total_seconds = minute * 60 + second + ms / 1000
|
| 32 |
+
result.append((total_seconds, text))
|
| 33 |
+
return result
|
| 34 |
+
|
| 35 |
+
def _load_audio_tokens(pt_file):
|
| 36 |
+
"""
|
| 37 |
+
Load MuCodec encoding of audio
|
| 38 |
+
"""
|
| 39 |
+
audio_ids = torch.load(pt_file, map_location="cpu").squeeze().long()
|
| 40 |
+
return audio_ids
|
| 41 |
+
|
| 42 |
+
def _get_token_slice(audio_tokens, start_s, end_s):
    """Render the tokens covering [start_s, end_s) seconds as an [SOA]...[EOA] string."""
    lo = int(start_s * TOKEN_PER_SECOND)
    hi = int(end_s * TOKEN_PER_SECOND)
    body = "".join(f"<AUDIO_{tok.item()}>" for tok in audio_tokens[lo:hi])
    return "[SOA]" + body + "[EOA]"
|
| 48 |
+
|
| 49 |
+
def _process_item(item, pt_dir:str):
    """Build the multi-turn conversation rounds for one song, or None.

    Returns None when the mucodec .pt file is missing or the lyrics contain
    no timestamped lines. Otherwise returns a list of {"role", "content"}
    dicts: an intro prompt + intro tokens, then one user/assistant pair per
    lyric line, and a final pair for the last line plus the song ending.
    """
    song_name = item.get("song") or item.get("name")
    # NOTE(review): raises AttributeError if both 'song' and 'name' are
    # missing/None — presumably upstream guarantees one exists; confirm.
    song_name = song_name.split('.mp3')[0] # For mucodec, remove extension
    pt_file = os.path.join(pt_dir, f"{song_name}.pt")
    if not os.path.exists(pt_file):
        return None

    audio_tokens = _load_audio_tokens(pt_file)
    # Prefer whichever lyric variant is longer (translated vs original)
    tlyric_ = item.get('tlyric', "")
    lyric_ = item.get('lyric', "")
    lyric = tlyric_ if len(tlyric_) > len(lyric_) else lyric_
    lyrics_ts = _parse_lyric_with_timestamps(lyric)

    if not lyrics_ts:
        # Skip if no lyrics
        return None

    rounds = []

    # First generate a system message containing song information
    intro_text = (
        f"请生成一首歌曲,歌名为《{item.get('name', '')}》,风格是{item.get('style','')}"
        f",情绪为{item.get('emotion','')},节奏:{item.get('rhythm','')},"
        f"{item.get('description','')},由{item.get('singer','')}演唱,语言:{item.get('lang','')}。"
        f"歌词如下:" + " ".join([text for _, text in lyrics_ts]) + "接下来我会逐句告诉你需要生成歌曲片段的歌词,\n请先生成前奏"
    )
    rounds.append({"role": "user", "content": intro_text})
    rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, 0, lyrics_ts[0][0])}) # Intro tokens

    # Each lyric line corresponds to one round
    for idx, (start_s, text) in enumerate(lyrics_ts[:-1]): ## Last line handled separately
        # NOTE(review): the else-branch below is dead code — while iterating
        # lyrics_ts[:-1], idx + 1 < len(lyrics_ts) is always true.
        end_s = lyrics_ts[idx + 1][0] if idx + 1 < len(lyrics_ts) else len(audio_tokens)/TOKEN_PER_SECOND # Last line to end of audio
        rounds.append({"role": "user", "content": text})
        rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, start_s, end_s)})

    # Tail processing logic: final lyric line spans to the end of the audio
    rounds.append({"role": "user", "content": f"请生成歌词{lyrics_ts[-1][1]}以及歌曲结尾"})
    rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, lyrics_ts[-1][0], len(audio_tokens)/TOKEN_PER_SECOND)})

    return rounds
|
| 89 |
+
|
| 90 |
+
# ===== External Interface =====
|
| 91 |
+
|
| 92 |
+
def get_convert_convs(dataset: list[dict], pt_dir: str, save_path: str):
    """Convert every item into multi-turn conversation rounds and write one
    {"messages": [...]} JSON object per line to save_path; items that cannot
    be processed (missing .pt file or no lyrics) are silently skipped."""
    with open(save_path, "w", encoding="utf-8") as out_file:
        for entry in tqdm(dataset, desc="Converting convs"):
            conversation = _process_item(entry, pt_dir)
            if not conversation:
                continue
            record = json.dumps({"messages": conversation}, ensure_ascii=False)
            out_file.write(record + "\n")
|
data_pipeline/meta_process/convert_lyrics.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
import copy
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from my_tool import dict_sort_print
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from convert_convs import _parse_lyric_with_timestamps
|
| 9 |
+
|
| 10 |
+
# ===== Lyric Parsing =====
|
| 11 |
+
|
| 12 |
+
def _parse_lyrics(text:str) -> dict:
    """Parse metadata, lyrics and timestamps from lyric information.

    Returns a dict with:
      - "lyrics_meta": {key: value} pairs from colon-separated header lines
      - "lyrics": lyric lines, with "<nop>" markers for gaps/endings
      - "lyrics_time": start time of each entry in "lyrics" (seconds, float,
        as produced by _parse_lyric_with_timestamps)
    """
    segs = text.split("\n")
    metadata = {
        "lyrics_meta": {},
        "lyrics": [],
        "lyrics_time": [],
    }
    for seg in segs:
        # Format: [time] metadata / lyrics
        results = _parse_lyric_with_timestamps(seg)
        for time, content in results:
            if ":" in content or ":" in content:
                # Metadata line: split on the first ASCII or fullwidth colon
                pos1 = content.find(":")
                pos2 = content.find(":")
                pos = pos1 if pos1 != -1 else pos2
                key = content[:pos].strip()
                value = content[pos+1:].strip()
                metadata["lyrics_meta"][key] = value
            elif time == "00:00.00":
                # Unstructured metadata at the beginning
                # NOTE(review): dead branch — `time` is a float (seconds), so
                # it can never equal the string "00:00.00"; confirm intent.
                continue
            elif len(metadata['lyrics']) == 0 and "/" in content:
                # Unstructured metadata at the beginning
                continue
            else:
                # Only keep English and space punctuation
                if len(content) == 0:
                    # Middle gap/end
                    if len(metadata['lyrics']) != 0 and metadata['lyrics'][-1] != "<nop>":
                        # If there's no previous segment (beginning), or previous segment is empty, don't record (merge)
                        metadata['lyrics'].append("<nop>")
                        metadata['lyrics_time'].append(time)
                else:
                    # NOTE(review): `time != "<nop>"` compares a float with a
                    # string and is therefore always True — likely meant to
                    # check the previous lyric, not the time.
                    if len(metadata['lyrics_time']) != 0 and metadata['lyrics_time'][-1] == time and time != "<nop>":
                        # Same timestamp means it's a translation (don't record)
                        continue
                    # Actual lyrics
                    metadata['lyrics'].append(content)
                    metadata['lyrics_time'].append(time)
    return metadata
|
| 54 |
+
|
| 55 |
+
# ===== Language Detection =====
|
| 56 |
+
|
| 57 |
+
def _count_ch_nan(text:str):
|
| 58 |
+
"""Count the number of Chinese and other non-English characters in a string"""
|
| 59 |
+
ch_num = 0
|
| 60 |
+
nan_num = 0
|
| 61 |
+
nan = ""
|
| 62 |
+
for c in text:
|
| 63 |
+
if '\u4e00' <= c <= '\u9fff':
|
| 64 |
+
ch_num += 1
|
| 65 |
+
elif ('a' <= c <= 'z') or ('A' <= c <= 'Z') or len(c.strip()) == 0:
|
| 66 |
+
continue
|
| 67 |
+
else:
|
| 68 |
+
nan_num += 1
|
| 69 |
+
nan += c
|
| 70 |
+
# if len(nan) > 0:
|
| 71 |
+
# print(nan)
|
| 72 |
+
return ch_num, nan_num
|
| 73 |
+
|
| 74 |
+
def _lang_decide(lyrics: list[str], val_limit: int = 5, word_limit=3) -> str:
    """
    Determine the language type of lyrics (en/zh/ez/instrument/nan)
    - val_limit: Only count if there are at least this many sentences
    - word_limit: Only count if a sentence has at least this many words
    """
    zh_lines = 0
    en_lines = 0
    other_lines = 0
    for line in lyrics:
        if line.strip() == "<nop>":
            continue
        # Replace punctuation (and similar symbols) with spaces before counting
        cleaned = re.sub(r"[''¥·′´(),。?""!@#$%^&*()?.'/,=+_—— !…《》<>0-9~※~;-・\"、☆|△【】#「」‖{}\[\]-]", " ", line)
        ch_num, nan_num = _count_ch_nan(cleaned)

        if nan_num > word_limit:
            other_lines += 1
            continue
        if ch_num > word_limit:
            zh_lines += 1

        # Strip CJK, then count space-separated fragments as "English words"
        latin_only = re.sub(r'[\u4e00-\u9fff]+', '', cleaned)
        if len(latin_only.split(" ")) > word_limit:
            en_lines += 1

    if other_lines > val_limit:
        return "nan"
    if zh_lines > val_limit and en_lines > val_limit:
        return "ez"
    if zh_lines > val_limit:
        return "zh"
    if en_lines > val_limit:
        return "en"
    return "instrument"
|
| 111 |
+
|
| 112 |
+
# ===== External Interface =====
|
| 113 |
+
|
| 114 |
+
def get_convert_lyrics(dataset:list[dict], save_path:str, dir:str, src_subfix:str=""):
    """Convert lyrics and annotate language type (need to locate corresponding song).

    Writes one converted record per line to save_path and returns
    (new_dataset, unmatch).

    NOTE(review): the `dir` parameter is never used inside this function (and
    shadows the builtin); `unmatch` is always returned empty — confirm both
    are intentional.
    """
    new_dataset = []
    lang_count = defaultdict(int)  # per-language tally for the summary print
    unmatch = []
    with open(save_path, 'w', encoding='utf-8') as file:
        for ele in tqdm(dataset, desc="Converting Lyrics"):
            # Deep-copy so the caller's dataset is not mutated by the del/update below
            ele = copy.deepcopy(ele)
            # Skip if no lyrics
            if not ele['has_lyric']:
                # Don't add to final result
                continue
            # Get lyrics (fall back to the translated lyric when empty)
            lyric = ele['lyric']
            if lyric == "":
                lyric = ele['tlyric']

            # Parse lyrics
            new_data = _parse_lyrics(lyric)

            # Language detection
            lang = _lang_decide(new_data['lyrics'])
            lang_count[lang] += 1

            # Remove redundant fields
            # NOTE(review): these raise KeyError if any field is absent —
            # assumes every record carries all four keys; confirm.
            del ele['artists']
            del ele['lyric']
            del ele['tlyric']
            del ele['has_lyric']

            # Add new fields
            ele['lyric_lang'] = lang
            ele['source'] += src_subfix
            for key, value in new_data.items():
                ele[key] = value

            new_dataset.append(ele)
            json.dump(ele, file, ensure_ascii=False)
            file.write("\n")

    dict_sort_print(lang_count)
    return new_dataset, unmatch
|
| 156 |
+
|
| 157 |
+
def get_match_music(music_data: list[dict], lyric_data: list[dict]):
    """Split songs into those with a matching lyric record and those without.

    A lyric record is keyed by "<name-without-spaces> - <artist>.mp3"; a song
    matches when its file basename equals that key. Matched lyric records are
    given a 'path' field pointing at the audio file.
    """
    # 1. Build a lookup table keyed by the expected mp3 filename
    lookup = {}
    for record in tqdm(lyric_data, desc="Existing Lyrics"):
        stripped_name = re.sub(" ", "", record['name'])
        lookup[f"{stripped_name} - {record['artist']}.mp3"] = record

    # 2. Partition songs by whether their basename appears in the table
    matches = []
    unmatches = []
    for song in tqdm(music_data, desc="Check Matching"):
        song_path = song['path']
        record = lookup.get(os.path.basename(song_path))
        if record is None:
            unmatches.append(song)
        else:
            record['path'] = song_path
            matches.append(record)
    return matches, unmatches
|
data_pipeline/meta_process/convert_messages.py
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate multi-turn dialogue data based on segment descriptions (segment-by-segment generation).
|
| 3 |
+
|
| 4 |
+
Input:
|
| 5 |
+
- MuseData/data/meta_suno_cn.jsonl # Base meta (contains src_path, lyrics, lang, tag, etc.)
|
| 6 |
+
- 3_block_80000_cn_desc.jsonl # Segment descriptions, contains startS/endS/text/desc for each section
|
| 7 |
+
- mucodec pt directory: same as PT_DIR_CN in multi_data_suno.py
|
| 8 |
+
|
| 9 |
+
Output:
|
| 10 |
+
- MuseData/sft_dataset_suno_cn.jsonl # Multi-turn dialogues generated segment by segment
|
| 11 |
+
|
| 12 |
+
Dialogue format example (refer to temp1.jsonl):
|
| 13 |
+
- First user: Summary prompt (Chinese), explains "segment-by-segment" + provides [Intro dsec...] description
|
| 14 |
+
- First assistant: Intro tokens (0 ~ first section's startS)
|
| 15 |
+
- Subsequent segments:
|
| 16 |
+
user.content = "[{Section} dsec]{desc}\\n{text}"
|
| 17 |
+
assistant.content = corresponding time segment tokens
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import re
|
| 23 |
+
from typing import Dict, List, Tuple, Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
from my_tool import load_jsonl, clean_newlines
|
| 28 |
+
|
| 29 |
+
# Language configuration
|
| 30 |
+
LANG = "en"
|
| 31 |
+
# Path configuration
|
| 32 |
+
META_FILE = f"meta_suno_{LANG}.jsonl"
|
| 33 |
+
# desc folder, read all *.jsonl files in it
|
| 34 |
+
DESC_DIR = "desc"
|
| 35 |
+
PT_DIR = f"suno_mucodec_{LANG}"
|
| 36 |
+
# Output directory (different from meta directory), each desc file generates a set of three files
|
| 37 |
+
OUTPUT_DIR = "outputs"
|
| 38 |
+
OUTPUT_BASENAME = "minus_phonemes"
|
| 39 |
+
|
| 40 |
+
TOKEN_PER_SECOND = 25
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
LOG_FILE = os.path.join(OUTPUT_DIR, "section_mismatch_cn.log") # Place in output directory
|
| 44 |
+
|
| 45 |
+
def _log_warning(msg: str):
    """Print *msg* and append it to the shared mismatch log file (LOG_FILE)."""
    print(msg, end='\n')
    with open(LOG_FILE, 'a', encoding='utf-8') as log:
        log.write(f"{msg}\n")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Timestamp parsing regex
|
| 53 |
+
timestamp_pattern = re.compile(
|
| 54 |
+
r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_pt(pt_file: str) -> torch.Tensor:
    """Load a MuCodec token file as a squeezed CPU int64 tensor."""
    tensor = torch.load(pt_file, map_location="cpu")
    return tensor.squeeze().long()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_token_slice(audio: torch.Tensor, start_s: float, end_s: float) -> str:
    """Render the audio tokens covering [start_s, end_s) as "[SOA]...[EOA]".

    Negative times are clamped to 0 and indices are clamped to the tensor
    length; an empty or inverted window yields "[SOA][EOA]".
    """
    start_s = max(start_s, 0)
    end_s = max(end_s, 0)
    n = audio.shape[0]
    lo = max(0, min(int(start_s * TOKEN_PER_SECOND), n))
    hi = max(0, min(int(end_s * TOKEN_PER_SECOND), n))
    window = audio[lo:hi] if hi > lo else []
    return "[SOA]" + "".join(f"<AUDIO_{int(tok)}>" for tok in window) + "[EOA]"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def infer_pt_path(src_path: str) -> Optional[str]:
    """Map an audio source path to its expected .pt token file under PT_DIR, or None for empty input."""
    if not src_path:
        return None
    stem, _ext = os.path.splitext(os.path.basename(src_path))
    return os.path.join(PT_DIR, f"{stem}.pt")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def parse_lyric_with_timestamps(lyric: str) -> List[Tuple[float, str]]:
    """
    Parse [(start_time_s, text), ...] from lyrics with timestamps, in order
    of appearance. The returned timestamps come from the lyrics field in the
    meta file.

    Fractional seconds are scaled by digit count: ".5" -> 0.5 s,
    ".50" -> 0.50 s, ".500" -> 0.500 s. (Previously a one-digit fraction fell
    into the /1000 branch and was read as milliseconds, i.e. ".5" -> 0.005 s.)
    Entries with empty text are kept, matching the original behavior.
    """
    # Local compile keeps the function self-contained; re caches patterns.
    pattern = re.compile(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]")
    result: List[Tuple[float, str]] = []
    matches = list(pattern.finditer(lyric))

    for i, match in enumerate(matches):
        start_idx = match.end()
        # Text runs until the next timestamp (or end of string)
        end_idx = matches[i + 1].start() if i + 1 < len(matches) else len(lyric)
        text = lyric[start_idx:end_idx].strip()

        minutes = int(match.group(1))
        seconds = int(match.group(2))
        ms_str = match.group(3)
        # Scale the fractional part by its digit count
        fractional_seconds = int(ms_str) / (10 ** len(ms_str)) if ms_str else 0.0

        result.append((minutes * 60 + seconds + fractional_seconds, text))
    return result
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def extract_section_from_text(text: str) -> Optional[str]:
    """
    Relaxed section extraction: return the stripped content of the first
    bracket pair whose content starts with an English letter (e.g.
    "[Verse 1]" -> "Verse 1"); None when no such bracket exists.
    """
    found = re.search(r'\[([A-Za-z][A-Za-z0-9\s\-\(\)]*)\]', text)
    if found is None:
        return None
    return found.group(1).strip()
|
| 126 |
+
|
| 127 |
+
def format_section_label(sec_name: str) -> str:
    """Return the section label with surrounding whitespace removed; interior spacing is preserved."""
    label = sec_name.strip()
    return label
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def normalize_section_name(sec_name: str) -> str:
    """
    Canonicalize a section name for fuzzy matching: drop all spaces,
    lowercase, and strip any trailing run of digits
    ("Verse 1" -> "verse", "chorus1" -> "chorus").
    """
    collapsed = sec_name.replace(" ", "").lower()
    return re.sub(r"\d+$", "", collapsed)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def clean_desc(desc: str) -> str:
    """
    Normalize a section description:
    1. strip surrounding whitespace,
    2. drop a leading "[desc]" marker if present,
    3. unwrap a single pair of enclosing brackets.
    Falsy input (empty string / None) is returned unchanged.
    """
    if not desc:
        return desc

    cleaned = desc.strip()

    # Drop a leading "[desc]" marker
    if cleaned.startswith("[desc]"):
        cleaned = cleaned[len("[desc]"):].strip()

    # Unwrap enclosing brackets
    if cleaned.startswith("[") and cleaned.endswith("]"):
        cleaned = cleaned[1:-1].strip()

    return cleaned
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def build_desc_map(desc_path_or_dir: str) -> Dict[Tuple[str, int], dict]:
    """
    Load per-song section descriptions keyed by (song_id, track_index).

    Accepts either a single .jsonl file or a directory containing multiple
    .jsonl files. In the directory case files are read in sorted name order,
    so a later record with the same key overwrites an earlier one. Lines that
    fail to parse as JSON are appended to "error.txt" and skipped.

    The mapped value is the whole JSON object (including its "sections"
    list) — the previous `List[dict]` return annotation was wrong for the
    code, which stores `obj`, not `obj["sections"]`.
    """
    mapping: Dict[Tuple[str, int], dict] = {}

    # Collect the jsonl file(s) to read
    paths: List[str] = []
    if os.path.isdir(desc_path_or_dir):
        for name in sorted(os.listdir(desc_path_or_dir)):
            if name.endswith(".jsonl"):
                paths.append(os.path.join(desc_path_or_dir, name))
    else:
        paths.append(desc_path_or_dir)

    for path in paths:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    obj = json.loads(line)
                except Exception:
                    # Keep the unparseable line for later inspection instead of crashing
                    with open("error.txt", 'a', encoding='utf-8') as error_file:
                        error_file.write(line + "\n")
                    continue

                song_id = obj.get("song_id")
                track_idx = obj.get("track_index", 0)
                # Store the entire object (not just its sections)
                mapping[(song_id, track_idx)] = obj
    return mapping
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def extract_suffix_from_desc_path(desc_path: str, fallback_idx: int) -> str:
    """
    Derive an output-naming suffix from a desc filename.

    Tries to pull <suffix> out of "3_block_<suffix>_(cn|en)_desc.jsonl"
    (case-insensitive); when the filename does not match, falls back to
    str(fallback_idx).
    """
    match = re.search(
        r"3_block_([^_]+)_(?:cn|en)_desc\.jsonl",
        os.path.basename(desc_path),
        re.IGNORECASE,
    )
    return match.group(1) if match else str(fallback_idx)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def extract_suffix_num(desc_path: str, fallback_idx: int) -> int:
    """
    Numeric variant of extract_suffix_from_desc_path, used to process desc
    files in numeric order; a non-numeric suffix falls back to fallback_idx.
    """
    raw = extract_suffix_from_desc_path(desc_path, fallback_idx)
    try:
        return int(raw)
    except ValueError:
        return fallback_idx
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def build_messages(item: dict, obj: dict, audio: torch.Tensor) -> Optional[dict]:
    """Build a multi-turn training sample from one meta item + its desc object.

    Timestamps parsed from the meta item's lyrics drive the audio slicing;
    ``obj["sections"]`` supplies the per-section descriptions (matched in
    order against the middle sections, i.e. excluding Intro/Outro which do
    not appear in the meta lyrics).

    Returns the sample dict (meta fields + ``messages``) or None when the
    desc object is empty or no timestamps can be parsed.
    """
    if not obj:
        return None
    desc_sections = obj.get("sections", [])

    # Parse timestamps from meta's lyrics
    lyrics_raw = item.get("lyrics", "") or ""
    meta_ts_list = parse_lyric_with_timestamps(lyrics_raw)
    if not meta_ts_list:
        return None

    # Total duration derived from the token length of the audio tensor.
    total_seconds = audio.shape[0] / float(TOKEN_PER_SECOND)

    # Sort desc_sections (by startS, for matching)
    desc_sections = sorted(desc_sections, key=lambda x: x.get("startS", 0.0))

    # Get middle sections from desc_sections (skip first Intro and last Outro).
    # Intro and Outro are not in meta's lyrics, so they are handled separately;
    # the middle sections are consumed in order as meta sections are found.
    middle_desc_sections = desc_sections[1:-1] if len(desc_sections) > 2 else desc_sections[1:] if len(desc_sections) > 1 else []

    # Identify sections from meta timestamps and build the mapping.
    # One section may span multiple lyric lines, which are merged.
    section_timestamps: List[Tuple[str, float, float, str, str]] = []  # (section_name, start_s, end_s, text, desc)
    current_section: Optional[Tuple[str, float, str]] = None  # (section_name, start_s, accumulated_text)
    desc_idx = 0  # Cursor into middle_desc_sections, advanced in order

    for idx, (start_s, text) in enumerate(meta_ts_list):
        # Extract section name (None for plain lyric lines)
        section_name = extract_section_from_text(text)

        if section_name:
            # Encountered a new section label: flush the previous section first.
            if current_section:
                # The previous section ends where the new one starts.
                prev_sec_name, prev_start_s, prev_text = current_section
                # Only strip timestamps; keep everything else (line breaks, labels, ...)
                clean_prev_text = re.sub(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]", "", prev_text)

                # Take the next middle-section desc, if any remain.
                prev_desc = ""
                if desc_idx < len(middle_desc_sections):
                    prev_desc = clean_desc(middle_desc_sections[desc_idx].get("desc", ""))
                    desc_idx += 1

                section_timestamps.append((prev_sec_name, prev_start_s, start_s, clean_prev_text, prev_desc))

            # Start accumulating the new section.
            current_section = (section_name, start_s, text)
        else:
            # No section label: this line belongs to the current section.
            if current_section:
                sec_name, sec_start, sec_text = current_section
                # Preserve line structure by joining with newlines.
                current_section = (sec_name, sec_start, sec_text + "\n" + text)
            # If there is no current_section, skip (e.g. a line before any label).

    # Process the last section.
    # A trailing timestamp with empty text acts as an end-of-lyrics marker.
    outro_start_s: Optional[float] = None
    if meta_ts_list and not meta_ts_list[-1][1].strip():
        # Last timestamp has empty text: it marks where the Outro begins.
        outro_start_s = meta_ts_list[-1][0]

    if current_section:
        sec_name, sec_start, sec_text = current_section
        # End the last section at the end-marker timestamp if present,
        # otherwise at the total audio duration.
        if outro_start_s is not None:
            end_s = outro_start_s
        else:
            end_s = total_seconds

        # Only strip timestamps; keep everything else.
        clean_text = re.sub(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]", "", sec_text)

        # Take the next middle-section desc, if any remain.
        desc = ""
        if desc_idx < len(middle_desc_sections):
            desc = clean_desc(middle_desc_sections[desc_idx].get("desc", ""))
            desc_idx += 1

        section_timestamps.append((sec_name, sec_start, end_s, clean_text, desc))

    # Sanity check: middle-section count should match the sections found in meta.
    if desc_idx != len(middle_desc_sections):
        _log_warning(f"⚠️ Warning: Section count mismatch! meta has {len(section_timestamps)} sections, desc has {len(middle_desc_sections)} middle sections (excluding Intro and Outro) (song_id: {item.get('song_id')}, track_index: {item.get('track_index')})")

    if not section_timestamps:
        return None

    # Intro segment: from 0 to the first section's start time.
    # The Intro is absent from meta lyrics, so its desc comes from the first
    # desc_sections entry (when that entry is labeled "intro").
    first_section_start = section_timestamps[0][1] if section_timestamps else total_seconds
    intro_desc = ""
    if desc_sections and desc_sections[0].get("section", "").lower() == "intro":
        intro_desc = clean_desc(desc_sections[0].get("desc", ""))

    # Style tag selection: Chinese songs use the omni tag directly; English
    # songs pick whichever of omni/style scored the higher similarity.
    song_id:str = obj.get("song_id", "")
    omni_tag = obj.get("omni", "")
    style_tag = obj.get("style", "")

    if song_id.find("cn") != -1:
        # Chinese songs use omni directly
        tag = omni_tag
    else:
        # English songs compare omni / style similarity scores
        style_sim = obj.get("style_sim", 0)
        omni_sim = obj.get("omni_sim", 0)
        try:
            tag = omni_tag if omni_sim > style_sim else style_tag
        except Exception as e:
            # If a sim score is invalid (non-comparable), default to omni.
            tag = omni_tag
            print(f"Error: {song_id}, {e}")

    # Opening user turn: global style + the Intro section header.
    intro_prompt = (
        f"Please generate a song in the following style:{tag}.\n"
        "Next, I will tell you the requirements and lyrics for the song fragment to be generated, section by section.\n"
        f"[Intro][desc:{intro_desc}]"
    )

    messages: List[dict] = []
    messages.append({"role": "user", "content": intro_prompt})
    messages.append(
        {
            "role": "assistant",
            # Intro audio: everything before the first lyric section.
            "content": get_token_slice(audio, 0.0, first_section_start),
        }
    )

    # One user/assistant pair per section, sliced by the meta timestamps.
    for idx, (sec_name, start_s, end_s, text, desc) in enumerate(section_timestamps):
        # user content: [Section][desc:...][lyrics:...] (original spacing preserved)
        label = format_section_label(sec_name)
        content = f"[{label}]"
        content += f"[desc:{desc}]"
        # Strip the leading [Section] label line from the lyrics body.
        lyrics_text = re.sub(r'^\[.*?\]\s*\n?', '', text.strip())
        if lyrics_text:
            content += f"[lyrics:\n{lyrics_text}]"
        messages.append({"role": "user", "content": content})
        messages.append(
            {
                "role": "assistant",
                "content": get_token_slice(audio, start_s, end_s),
            }
        )

    # If an end marker exists before the audio ends, add an Outro segment whose
    # desc comes from the last desc_sections entry (when labeled "outro").
    if outro_start_s is not None and outro_start_s < total_seconds:
        outro_desc = ""
        if desc_sections and desc_sections[-1].get("section", "").lower() == "outro":
            outro_desc = clean_desc(desc_sections[-1].get("desc", ""))

        messages.append({"role": "user", "content": f"[Outro][desc:{outro_desc}]"})
        messages.append(
            {
                "role": "assistant",
                "content": get_token_slice(audio, outro_start_s, total_seconds),
            }
        )

    sample = {
        "song_id": item.get("song_id"),
        "track_index": item.get("track_index"),
        "src_path": item.get("src_path"),
        "tag": item.get("tag"),
        "lang": item.get("lang"),
        "duration": item.get("duration"),
        "messages": messages,
    }
    return sample
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def process_with_desc(desc_path: str, suffix: str) -> None:
    """Generate output files using a single desc file (messages-only / meta-only).

    Reads META_FILE, joins each item with its desc object (keyed by
    (song_id, track_index)), builds the chat-style sample, and appends one
    line per kept item to two JSONL outputs. Does not generate a main file.
    Supports resuming: previously written lines are skipped on restart.
    """
    desc_map = build_desc_map(desc_path)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # messages-only naming: remove suffix, only keep block number (e.g., ..._8000.jsonl)
    out_msg = os.path.join(OUTPUT_DIR, f"{OUTPUT_BASENAME}_{suffix}.jsonl")
    out_meta = os.path.join(OUTPUT_DIR, f"{OUTPUT_BASENAME}_{suffix}_meta_only.jsonl")

    dataset = load_jsonl(META_FILE)
    total = len(dataset)
    kept = 0
    skipped = 0

    # Resume support.
    # BUG FIX: the original looped `for path in [out_msg, out_meta]` and sliced
    # `dataset` inside the loop, so when BOTH outputs existed the dataset was
    # sliced twice and a batch of records was silently skipped. It also called
    # load_jsonl on the other (possibly missing) file. Perform the resume
    # check exactly once, and only when both outputs are present.
    if os.path.exists(out_msg) and os.path.exists(out_meta):
        already_msg = load_jsonl(out_msg)
        already_meta = load_jsonl(out_meta)
        # The two outputs are written in lockstep, so their lengths must agree.
        assert len(already_meta) == len(already_msg)
        dataset = dataset[len(already_msg):]

    with open(out_msg, "a", encoding="utf-8") as fout_msg, \
            open(out_meta, "a", encoding="utf-8") as fout_meta:

        for item in tqdm(dataset, desc=f"Processing meta_suno_{LANG}.jsonl (desc: {suffix})"):
            key = (item.get("song_id"), item.get("track_index", 0))

            obj = desc_map.get(key)

            if not obj:
                skipped += 1
                continue

            # Locate the pre-tokenized audio tensor for this item.
            pt_path = infer_pt_path(item.get("src_path", ""))
            if not pt_path or not os.path.exists(pt_path):
                skipped += 1
                continue

            audio = load_pt(pt_path)
            sample = build_messages(item, obj, audio)
            if not sample:
                skipped += 1
                continue

            # Write messages-only line
            messages_only = {"messages": sample.get("messages", [])}
            fout_msg.write(json.dumps(messages_only, ensure_ascii=False) + "\n")
            # Write meta-only line (everything except the messages)
            meta_only = {k: v for k, v in sample.items() if k != "messages"}
            fout_meta.write(json.dumps(meta_only, ensure_ascii=False) + "\n")
            kept += 1

    print(f"✅ messages-only: {out_msg}")
    print(f"✅ meta-only: {out_meta}")
    print(f"Total {total}, kept {kept}, skipped {skipped}")
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def convert_train_valid():
    """Run process_with_desc for every desc file in DESC_DIR whose stem ends with LANG."""
    if not os.path.isdir(DESC_DIR):
        print(f"⚠️ DESC_DIR does not exist: {DESC_DIR}")
        return

    candidates = [
        os.path.join(DESC_DIR, fname)
        for fname in os.listdir(DESC_DIR)
        if fname.endswith(".jsonl")
    ]
    # Numeric order on the extracted suffix so e.g. 8000 is handled before 16000.
    candidates.sort(key=lambda p: extract_suffix_num(p, 0))

    if not candidates:
        print(f"⚠️ No desc files found: {DESC_DIR}")
        return

    for desc_path in candidates:
        stem = os.path.splitext(os.path.basename(desc_path))[0]
        # Only process files for the configured language.
        if stem.endswith(LANG):
            process_with_desc(desc_path, stem)
|
| 502 |
+
|
| 503 |
+
# Test-set conversion notes: assistant turns are required by the chat format,
# but their content is left empty (to be generated at inference time).
# Each section carries 3 candidate descs; pick the one with the highest similarity score.
# There are also two omni tags; pick the one with the larger similarity (style is not used).
|
| 506 |
+
|
| 507 |
+
from meta_phonemes import _get_lyrics, _trans_sentences
|
| 508 |
+
|
| 509 |
+
def _form_section(section:dict, en:bool) -> str:
    """Format one section into "[Label][desc:...][lyrics:...][phonemes]".

    Picks, among the three candidate descriptions, the one with the highest
    similarity score; strips any leading "[...]" label from the lyric text;
    and derives phonemes from the lyrics via meta_phonemes helpers.
    """
    # Segment label, e.g. "[Verse]"
    section_tag = f"[{section['section']}]"
    # Segment description: choose the candidate with the highest similarity
    descs = [section['desc1'], section['desc2'], section['desc3']]
    sims = [section['desc1_sim'], section['desc2_sim'], section['desc3_sim']]
    max_sim = max(sims)
    max_index = sims.index(max_sim)
    desc:str = descs[max_index]

    if desc == "音频过短":  # "Audio too short" sentinel from the data pipeline — keep as-is
        desc = "[desc:]"
    else:
        DESC_START = "[desc] "
        if desc.startswith(DESC_START):
            # Drop the "[desc] " prefix and close the bracket.
            desc = desc[len(DESC_START):] + "]"
        # NOTE(review): this unconditionally drops the first character of desc
        # before prepending "[desc:". That is correct when desc starts with
        # "[" (the usual case after the branch above), but would corrupt a
        # desc that does not — confirm all inputs start with "[".
        desc = "[desc:" + desc[1:]

    # Lyrics & phonemes
    text:str = section['text']
    if text.find(']') != -1:
        # Remove the leading segment label (everything up to the last ']').
        start = text.rfind(']')
        text = text[start+1:]
    if len(text.strip()) == 0:
        # Opening segment has no lyrics/phonemes
        lyrics = ""
        phonemes = ""
    else:
        if en:
            # English lyrics are normalized to one line per sentence.
            lyrics = "[lyrics:\n" + clean_newlines(text) + "]"
        else:
            lyrics = "[lyrics:" + text + "]"

        # _get_lyrics returns (sentences, normalized-lyrics-block);
        # _trans_sentences converts the sentences to a phoneme block.
        sentences, lyrics = _get_lyrics(lyrics)
        phonemes = _trans_sentences(sentences)
    return section_tag + desc + lyrics + phonemes
|
| 547 |
+
|
| 548 |
+
def _form_intro(ele:dict) -> str:
|
| 549 |
+
"""Process to get multi-turn dialogue opening"""
|
| 550 |
+
omni1 = ele['omni1']
|
| 551 |
+
omni2 = ele['omni2']
|
| 552 |
+
omni_sim1 = ele['omni1_sim']
|
| 553 |
+
omni_sim2 = ele['omni2_sim']
|
| 554 |
+
tag = omni1 if omni_sim1 > omni_sim2 else omni2
|
| 555 |
+
return (
|
| 556 |
+
f"Please generate a song in the following style:{tag}.\n"
|
| 557 |
+
"Next, I will tell you the requirements and lyrics for the song fragment to be generated, section by section.\n"
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
def convert_test():
    """Convert the filtered test set into multi-turn chat messages.

    Reads "filter.jsonl", builds one user/assistant pair per section
    (assistant content left empty for generation), prepends the global
    style intro to the first user turn, and writes one JSON object per
    line to "messages.jsonl".
    """
    path = "filter.jsonl"
    dataset = load_jsonl(path)
    save_path = "messages.jsonl"
    with open(save_path, 'w', encoding='utf-8') as file:
        for ele in tqdm(dataset, desc=f"Converting {path}"):
            sections = ele['sections']
            # Renamed from `id` to avoid shadowing the builtin.
            song_id: str = ele['song_id']
            # BUG FIX: an entry with no sections crashed on messages[0] below;
            # skip it with a warning instead of aborting the whole run.
            if not sections:
                print(f"Warning: no sections for {song_id}, skipped")
                continue
            english = song_id.startswith("suno_test_en")

            messages = []
            for section in sections:
                content = _form_section(section, english)
                messages += [
                    {
                        "role": "user",
                        "content": content
                    },
                    {
                        "role": "assistant",
                        "content": ""
                    }
                ]
            # Prepend the global style intro to the first user turn.
            first_content = messages[0]['content']
            intro = _form_intro(ele)
            messages[0]['content'] = intro + first_content

            data = {"messages": messages}
            json.dump(data, file, ensure_ascii=False)
            file.write("\n")
|
| 591 |
+
|
| 592 |
+
if __name__ == "__main__":
    # Script entry point: build the chat-format test set.
    convert_test()
|
data_pipeline/meta_process/convert_segments.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from my_tool import path_join, load_json
|
| 5 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 6 |
+
|
| 7 |
+
def _check_label(label:str, max_length:int=30) -> bool:
|
| 8 |
+
"""Check if label is valid (non-empty, not timestamp, not long lyrics)"""
|
| 9 |
+
length = len(label.strip())
|
| 10 |
+
if length == 0:
|
| 11 |
+
# print("Error Label: Empty")
|
| 12 |
+
return False
|
| 13 |
+
if length > max_length:
|
| 14 |
+
# print(f"Error Label: Words - {label}")
|
| 15 |
+
return False
|
| 16 |
+
if label.find(":") != -1 and label.find(".") != -1:
|
| 17 |
+
# Considered as timestamp
|
| 18 |
+
# print(f"Error Label: Timestamp - {label}")
|
| 19 |
+
return False
|
| 20 |
+
return True
|
| 21 |
+
|
| 22 |
+
def _convert_one(path:str):
    """Segment a song's metadata, removing redundant content.

    Walks the aligned-word list, opening a new segment at each bracketed
    section label (e.g. "[Verse]...") and accumulating words/end-times into
    the current segment otherwise. Returns {"path", "song_id", "segments"}.
    """
    data = load_json(path)
    dir = os.path.dirname(path)
    # The audio file lives next to the JSON, named <song_id>_<track_index>.mp3
    name = f"{data['song_id']}_{data['track_index']}.mp3"
    path = path_join(dir, name)
    new_data = {
        "path": path,
        "song_id": data['song_id'],
        "segments": []
    }
    words_info = data['timestamped_lyrics']['alignedWords']  # Sentence-by-sentence information
    seg_info = None

    empty_head = False
    for id, word_info in enumerate(words_info):
        if not word_info['success']:
            # Skip words the aligner failed on.
            continue
        word:str = word_info['word']

        label = ""
        if word.startswith('['):
            # A bracketed word starts a new segment: flush the previous one.
            if seg_info is not None:
                new_data['segments'].append(seg_info)
            label_end = word.find(']')
            label = word[1:label_end]
            if not _check_label(label):
                # Not a real section label (timestamp/lyric): treat as plain word.
                label = ""

        if label != "":
            seg_info = {
                "start": word_info['startS'],
                "end": 0,
                "label": label,
                # Remainder of the word after "] " — assumes a space follows ']'.
                # NOTE(review): if no space follows, this drops one character; confirm.
                "word": word[label_end+2:]
            }
        elif seg_info is not None:
            # Extend the current segment with this word.
            seg_info['end'] = word_info['endS']
            seg_info['word'] += word
        else:
            # Words before the first valid label: flag but keep going.
            empty_head = True
    # NOTE(review): this post-loop block extends the last segment with the
    # final word again (already added inside the loop), and the last segment
    # is never appended to new_data['segments'] — confirm whether both are
    # intended before relying on the trailing segment.
    if seg_info is not None:
        seg_info['end'] = word_info['endS']
        seg_info['word'] += word
    else:
        empty_head = True
    if empty_head:
        # print(f"Empty Head, segment: {len(new_data['segments'])}, path: {path}")
        pass
    return new_data
|
| 72 |
+
|
| 73 |
+
# ===== External Interface =====
|
| 74 |
+
|
| 75 |
+
def get_convert_segments(data_dir: str, save_path: str, max_workers: int = 10):
    """Segment every song JSON under *data_dir* in parallel.

    Results are streamed to *save_path* as JSONL while also being collected
    and returned as a list (order follows task completion, not input order).
    """
    paths = [
        path_join(data_dir, name)
        for name in tqdm(os.listdir(data_dir), desc="Getting the JSON Paths")
        if name.endswith(".json")
    ]

    dataset = []
    with open(save_path, 'w', encoding='utf-8') as file, \
            ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_convert_one, p) for p in paths]
        with tqdm(total=len(futures), desc="Converting Segments") as pbar:
            for future in as_completed(futures):
                record = future.result()
                dataset.append(record)
                # Stream each record immediately so partial runs are recoverable.
                json.dump(record, file, ensure_ascii=False)
                file.write("\n")
                pbar.update(1)
    return dataset
|
data_pipeline/meta_process/evaluate_polyphones.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import jieba
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from my_tool import load_jsonl, save_json, load_json
|
| 5 |
+
from pypinyin import pinyin, Style, load_phrases_dict
|
| 6 |
+
from pypinyin_dict.phrase_pinyin_data import cc_cedict
|
| 7 |
+
|
| 8 |
+
# Load the CC-CEDICT phrase pinyin data so multi-character phrases resolve correctly.
cc_cedict.load()
# Syllabic nasals that can stand alone as a pinyin "syllable".
re_special_pinyin = re.compile(r'^(n|ng|m)$')

# Register manually-corrected pronunciations accumulated from previous
# evaluation runs (written back by evaluate_polyphones).
reference = load_json("poly_correct.json")
load_phrases_dict(reference)
|
| 14 |
+
|
| 15 |
+
def _filter(dataset: list[dict]):
    """Keep only test items whose target character is a true polyphone."""
    def _is_polyphone(ele: dict) -> bool:
        char = ele['sentence'][ele['pos']]
        candidates = pinyin(char, style=Style.NORMAL, heteronym=True)[0]
        return len(candidates) > 1

    new_dataset = [ele for ele in tqdm(dataset, desc="Filtering") if _is_polyphone(ele)]
    print(f"Filter non polyphone, {len(dataset)} -> {len(new_dataset)}")
    return new_dataset
|
| 27 |
+
|
| 28 |
+
def evaluate_polyphones(dataset: list[dict], save_fail: str):
    """Check pinyin accuracy on polyphone test items and record corrections.

    For each item, locates the jieba segment containing the target position,
    predicts its pinyin, and compares against the reference ``phone``.
    Mispredicted multi-character segments are written back (merged) into the
    correction dictionary at *save_fail*.
    """
    dataset = _filter(dataset)
    total = len(dataset)
    right = 0
    correct_dic = {}
    for ele in tqdm(dataset):
        pos = ele['pos']
        phone = ele['phone']
        sentence = ele['sentence']
        # Locate the jieba segment that contains character position `pos`.
        seg = None
        delta = 0
        length = 0
        for candidate in jieba.cut(sentence):
            if length <= pos < length + len(candidate):
                seg = candidate
                delta = pos - length  # Position within the segment
                break
            length += len(candidate)
        if seg is None:
            # BUG FIX: previously `seg`/`delta` stayed unbound (or stale from
            # the prior item) when `pos` fell outside the segmentation,
            # raising UnboundLocalError or silently scoring the wrong word.
            continue
        pred_phones = pinyin(seg, style=Style.NORMAL)
        pred_phone = pred_phones[delta][0]
        # "v" stands for "ü"; such predictions are accepted as correct.
        if pred_phone == phone or pred_phone.endswith("v"):
            right += 1
        elif len(pred_phones) > 1:
            # Corrected pronunciation (only meaningful for phrases)
            pred_phones[delta] = [phone]
            correct_dic[seg] = pred_phones
    # BUG FIX: guard against division by zero when nothing survives filtering.
    if total:
        print(f"Acc: {(right / total):.2f}")
    else:
        print("Acc: n/a (empty dataset)")

    # Merge the new corrections into the existing correction dictionary.
    origin_dic = load_json(save_fail)
    merge_dic = origin_dic | correct_dic
    save_json(merge_dic, save_fail)
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
    # Evaluate polyphone pinyin accuracy and merge corrections back into
    # poly_correct.json (which is also loaded at import time above).
    path = "polyphones.jsonl"
    dataset = load_jsonl(path)
    evaluate_polyphones(dataset, "poly_correct.json")
|
data_pipeline/meta_process/filter.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 4 |
+
|
| 5 |
+
def filter_lang(dataset: list[dict], langs: list[str]) -> list[dict]:
    """Keep only entries whose 'lang' field is one of the given languages."""
    kept = [
        ele for ele in tqdm(dataset, desc="Filtering Lang")
        if 'lang' in ele and ele['lang'] in langs
    ]
    print(f"filter: {len(dataset)} -> {len(kept)}")
    return kept
|
| 14 |
+
|
| 15 |
+
def _check_duration(ele, lower_bound, upper_bound):
    """Subprocess task: return *ele* if its audio duration is within bounds, else None.

    A bound of -1 disables that side of the check.
    """
    duration = librosa.get_duration(filename=ele['path'])
    too_short = lower_bound != -1 and duration < lower_bound
    too_long = upper_bound != -1 and duration > upper_bound
    return None if too_short or too_long else ele
|
| 23 |
+
|
| 24 |
+
def filter_length(dataset: list[dict], lower_bound: int = -1, upper_bound: int = -1, max_worker: int = 4) -> list[dict]:
    """Keep items whose audio duration lies in [lower_bound, upper_bound].

    A bound of -1 disables that side of the check. Durations are probed in
    a process pool; output order follows task completion.
    """
    kept: list[dict] = []
    with ProcessPoolExecutor(max_workers=max_worker) as pool:
        pending = [
            pool.submit(_check_duration, ele, lower_bound, upper_bound)
            for ele in dataset
        ]
        with tqdm(total=len(pending), desc="Filtering Length") as bar:
            for done in as_completed(pending):
                ele = done.result()
                if ele is not None:
                    kept.append(ele)
                bar.update(1)
    print(f"filter: {len(dataset)} -> {len(kept)}")
    return kept
|
data_pipeline/meta_process/main.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from my_tool import (
|
| 2 |
+
load_json,
|
| 3 |
+
load_jsonl,
|
| 4 |
+
load_txt,
|
| 5 |
+
save_jsonl,
|
| 6 |
+
format_meta,
|
| 7 |
+
pure_name,
|
| 8 |
+
BASE_DIR,
|
| 9 |
+
compose_analyze,
|
| 10 |
+
get_sample,
|
| 11 |
+
get_field_suno,
|
| 12 |
+
tags_analyze,
|
| 13 |
+
find_json,
|
| 14 |
+
show_dir,
|
| 15 |
+
convert_mp3,
|
| 16 |
+
tar_dir,
|
| 17 |
+
tar_size_check,
|
| 18 |
+
clean_newlines,
|
| 19 |
+
dict_sort_print,
|
| 20 |
+
)
|
| 21 |
+
from meta_lang import load_asr_model, get_lang_meta
|
| 22 |
+
from meta_tags import load_tag_model, get_tags_meta
|
| 23 |
+
from meta_endpoints import get_endpoints_meta
|
| 24 |
+
from meta_phonemes import get_phonemes_meta
|
| 25 |
+
from filter import filter_lang, filter_length
|
| 26 |
+
from convert_convs import get_convert_convs
|
| 27 |
+
from convert_segments import get_convert_segments
|
| 28 |
+
from convert_lyrics import get_convert_lyrics, get_match_music
|
| 29 |
+
|
| 30 |
+
def pipeline():
    """Run the metadata pipeline for one batch: load/format raw meta,
    filter by audio length, then generate style tags.

    Commented-out stages (language tagging/filtering) are deliberate
    toggles — uncomment to enable them.
    """
    import os
    dir = "suno_batch"
    name = pure_name(dir)
    save_dir = BASE_DIR / f"data/{name}"

    # Initialize paths (only once); reuse raw.jsonl on re-runs.
    os.makedirs(save_dir, exist_ok=True)
    raw_path = os.path.join(save_dir, "raw.jsonl")
    if os.path.exists(raw_path):
        dataset = load_jsonl(raw_path)
    else:
        dataset = format_meta(dir)
        save_jsonl(dataset, raw_path)

    # Length filtering: cap at 1000 items, keep durations in [120, 360] s.
    dataset = dataset[:1000]
    max_workers = 10
    dataset = filter_length(dataset, 120, 360, max_workers)

    # Language tagging (disabled)
    # lang_bs = 8
    # model = load_asr_model(lang_bs)
    # lang_path = os.path.join(save_dir, "meta_lang.jsonl")
    # dataset = get_lang_meta(model, dataset, lang_bs, lang_path)

    # Language filtering (disabled)
    # dataset = filter_lang(dataset, ['zh', 'en'])

    # Style tagging
    tag_bs = 4
    tag_path = os.path.join(save_dir, "meta_tags.jsonl")
    model, processor = load_tag_model()
    prompt_path = BASE_DIR / "prompts/new_tags.md"
    prompt = load_txt(prompt_path)
    get_tags_meta(model, processor, dataset, prompt, tag_bs, tag_path)
|
| 66 |
+
|
| 67 |
+
def repeat(func):
    """Keep calling *func* until one invocation completes without raising.

    Any exception is printed and the call is retried; intended as a crude
    auto-restart wrapper for long-running pipeline jobs.
    """
    finished = False
    while not finished:
        try:
            func()
            finished = True
        except Exception as exc:
            print(f"Error: {exc}")
|
| 76 |
+
if __name__ == "__main__":
    # Run the pipeline with auto-restart on any exception.
    repeat(pipeline)
|
data_pipeline/meta_process/meta_endpoints.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import webrtcvad
|
| 3 |
+
import collections
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from my_tool import dup_remove
|
| 6 |
+
from pydub import AudioSegment
|
| 7 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 8 |
+
|
| 9 |
+
def _frame_generator(frame_duration_ms, audio, sample_rate):
|
| 10 |
+
"""Split audio into frames"""
|
| 11 |
+
bytes_per_sample = 2
|
| 12 |
+
frame_size = int(sample_rate * frame_duration_ms / 1000.0) * bytes_per_sample
|
| 13 |
+
offset = 0
|
| 14 |
+
timestamp = 0.0
|
| 15 |
+
frame_duration = frame_duration_ms / 1000.0
|
| 16 |
+
while offset + frame_size < len(audio):
|
| 17 |
+
yield audio[offset:offset + frame_size], timestamp
|
| 18 |
+
timestamp += frame_duration
|
| 19 |
+
offset += frame_size
|
| 20 |
+
|
| 21 |
+
def _vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
    """Merge continuous vocal frames into (start, end) speech segments.

    Hysteresis over a ring buffer of padding_duration_ms worth of frames:
    speech starts when >90% of buffered frames are voiced, and ends when
    >90% are unvoiced. Returns a list of (start_s, end_s) tuples.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)

    triggered = False  # True while inside a speech segment
    speech_segments = []

    for frame_bytes, timestamp in frames:
        is_speech = vad.is_speech(frame_bytes, sample_rate)

        if not triggered:
            ring_buffer.append((frame_bytes, timestamp, is_speech))
            num_voiced = len([f for f in ring_buffer if f[2]])
            # Enough voiced frames buffered: the segment started at the
            # oldest buffered frame, not at the current one.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                start_time = ring_buffer[0][1]
                ring_buffer.clear()
        else:
            ring_buffer.append((frame_bytes, timestamp, is_speech))
            num_unvoiced = len([f for f in ring_buffer if not f[2]])
            # Enough unvoiced frames buffered: close the segment at the
            # end of the current frame.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                end_time = timestamp + (frame_duration_ms / 1000.0)
                speech_segments.append((start_time, end_time))
                triggered = False
                ring_buffer.clear()

    # If still in speech state at the end, close the last segment
    if triggered:
        end_time = timestamp + (frame_duration_ms / 1000.0)
        speech_segments.append((start_time, end_time))

    return speech_segments
|
| 54 |
+
|
| 55 |
+
def _one_process(path):
    """Detect vocal segments in one audio file.

    Returns a dict with 'start'/'end' (seconds) of the first/last detected
    segment and the full 'segments' list; start and end are -1 when no
    speech is found.
    """
    # 1. Normalize the audio to 16 kHz, mono, 16-bit — the format webrtcvad expects
    audio = AudioSegment.from_file(path)
    audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    sample_rate = audio.frame_rate
    audio_data = audio.raw_data

    # 2. Initialize VAD (aggressiveness 0-3; higher filters non-speech more
    # aggressively, i.e. is LESS likely to label a frame as speech)
    vad = webrtcvad.Vad(0)

    # 3. Slice the PCM stream into 30 ms frames
    frames = list(_frame_generator(30, audio_data, sample_rate))

    # 4. Merge voiced frames into segments (300 ms smoothing window)
    segments = _vad_collector(sample_rate, 30, 300, vad, frames)

    # If no vocals, set both start and end to -1
    if len(segments) == 0:
        return {
            "start": -1,
            "end": -1,
            "segments": [],
        }

    return {
        "start": segments[0][0],
        "end": segments[-1][1],
        "segments": segments,
    }
|
| 85 |
+
|
| 86 |
+
# ===== External Interface =====
|
| 87 |
+
|
| 88 |
+
def get_endpoints_meta(dataset:list[dict], save_path:str, max_workers:int=4, save_middle:bool=True):
    """
    Add endpoint labels to each audio in dataset (mainly for separated vocal audio)
    - Requires 'path' field in each data entry in dataset
    - Write fields: endpoints.start/end
    - Write to save_path in real-time (appended as JSONL)
    - save_middle determines whether to record each sentence's endpoints to save.segments field

    Returns the entries that were processed successfully in this run.
    """
    dataset = dup_remove(dataset, save_path, 'path', 'endpoints')
    new_dataset = []
    with open(save_path, 'a', encoding='utf-8') as file:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(_one_process, ele['path']): ele for ele in dataset}
            # `total` lets tqdm show real progress over an as_completed iterator.
            for future in tqdm(as_completed(futures), total=len(futures), desc="Detecting endpoints"):
                ele = futures[future]  # Get original element
                try:
                    result = future.result()
                except Exception as exc:
                    # Keep the batch alive on one bad file, but report the
                    # failure instead of swallowing it silently.
                    print(f"endpoint detection failed for {ele.get('path')}: {exc}")
                    continue
                ele['endpoints'] = {
                    "start": result['start'],
                    "end": result['end'],
                }
                if save_middle:
                    ele.setdefault('save', {})['segments'] = result['segments']
                new_dataset.append(ele)
                json.dump(ele, file, ensure_ascii=False)
                file.write("\n")
    return new_dataset
|
data_pipeline/meta_process/meta_lang.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from funasr import AutoModel
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from my_tool import get_free_gpu, dup_remove
|
| 8 |
+
|
| 9 |
+
# ===== ASR Model (External) =====
|
| 10 |
+
|
| 11 |
+
def load_asr_model(bs:int):
    """Load the SenseVoice ASR model used for lyric recognition.

    Args:
        bs: generation batch size; max_batch_size is set to twice this value.

    Returns:
        A funasr AutoModel placed on the freest GPU, with fsmn-vad splitting
        long audio into segments of at most 30 s.
    """
    device = f"cuda:{get_free_gpu()}"
    model = AutoModel(
        model="iic/SenseVoiceSmall",
        trust_remote_code=True,
        vad_model="fsmn-vad",
        vad_kwargs={"max_single_segment_time": 30000},  # milliseconds
        device=device,
        batch_size=bs,
        max_batch_size=bs * 2,
    )
    print(f"Using {device}")
    return model
|
| 25 |
+
|
| 26 |
+
# ===== ASR Parsing =====
|
| 27 |
+
|
| 28 |
+
def _struct2lang(text: str) -> str:
|
| 29 |
+
"""Extract language identifier from structured representation"""
|
| 30 |
+
start = text.find("|")
|
| 31 |
+
text = text[start+1:]
|
| 32 |
+
end = text.find("|")
|
| 33 |
+
return text[:end]
|
| 34 |
+
|
| 35 |
+
def _struct2lyrics(text:str) -> str:
|
| 36 |
+
start = text.rfind(">")
|
| 37 |
+
lyric = text[start+1:]
|
| 38 |
+
return lyric
|
| 39 |
+
|
| 40 |
+
def _struct_parse(text: str) -> Tuple[List[str], List[str]]:
    """Split the structured ASR output on ' <' and parse each chunk.

    Returns two parallel lists: per-sentence language tags and lyric texts.
    """
    chunks = text.split(" <")
    langs = [_struct2lang(chunk) for chunk in chunks]
    lyrics = [_struct2lyrics(chunk) for chunk in chunks]
    return langs, lyrics
|
| 48 |
+
|
| 49 |
+
# ===== ASR Processing =====
|
| 50 |
+
|
| 51 |
+
def _batch_asr(model, paths:List[str]) -> List[Tuple[List[str], List[str]]]:
    """Run batched speech recognition and parse the structured output.

    Returns, for each input path, a (langs, lyrics) pair of parallel
    per-sentence lists extracted from the model's tagged text output.
    """
    outputs = model.generate(
        input=paths,
        cache=None,
        language="auto",  # let the model detect the language per segment
        use_itn=True,
        batch_size_s=240,
        merge_vad=True,  # merge adjacent VAD segments before decoding
        merge_length_s=15,
    )
    return [_struct_parse(output['text']) for output in outputs]
|
| 63 |
+
|
| 64 |
+
# ===== Overall Language Detection =====
|
| 65 |
+
|
| 66 |
+
def _lang_decide(lang_lyrics:list[tuple[str, str]], val_limit:int=5, word_limit=5) -> str:
|
| 67 |
+
"""
|
| 68 |
+
Determine language based on sentence recognition information
|
| 69 |
+
- val_limit: Only count if there are at least this many sentences
|
| 70 |
+
- word_limit: Only count if a sentence has at least this many words
|
| 71 |
+
"""
|
| 72 |
+
lang_count = defaultdict(int)
|
| 73 |
+
seg_langs, seg_lyrics = lang_lyrics
|
| 74 |
+
for lang, lyric in zip(seg_langs, seg_lyrics):
|
| 75 |
+
lyric = lyric.strip()
|
| 76 |
+
if lang == "en":
|
| 77 |
+
words_num = len(lyric.split())
|
| 78 |
+
else:
|
| 79 |
+
words_num = len(lyric)
|
| 80 |
+
if words_num >= word_limit:
|
| 81 |
+
lang_count[lang] += 1
|
| 82 |
+
langs = []
|
| 83 |
+
for lang, count in lang_count.items():
|
| 84 |
+
if count >= val_limit:
|
| 85 |
+
langs.append(lang)
|
| 86 |
+
if len(langs) == 0:
|
| 87 |
+
return "pure"
|
| 88 |
+
elif len(langs) == 1:
|
| 89 |
+
return langs[0]
|
| 90 |
+
else:
|
| 91 |
+
return "multi: " + " ".join(langs)
|
| 92 |
+
|
| 93 |
+
# ===== External Interface =====
|
| 94 |
+
|
| 95 |
+
def get_lang_meta(model, dataset:list[dict], bs:int, save_path:str, save_middle:bool=True) -> list[dict]:
    """
    Perform language recognition on a JSONL dataset
    - Final language tag is saved to lang field, types include zh, en, ja, ko, yue, pure, multi, etc.
    - save_middle determines whether to save intermediate recognition results
      (sentence languages and lyrics) to save.langs / save.lyrics

    Returns the entries processed in this run.
    """
    dataset = dup_remove(dataset, save_path, 'path', 'lang')
    # BUGFIX: the length must be measured AFTER dup_remove — the previous
    # code measured the unfiltered list, so the batch loop ran past the end
    # of the filtered dataset on resumed runs.
    data_num = len(dataset)
    new_dataset = []
    with open(save_path, 'a', encoding='utf-8') as file:
        for i in tqdm(range(0, data_num, bs), desc="Lang detecting"):
            batch = []
            paths = []
            for ele in dataset[i:i+bs]:
                path = ele['path']
                if os.path.exists(path):
                    batch.append(ele)
                    paths.append(path)
            if not paths:
                # Every file in this batch is missing on disk; nothing to do.
                continue
            lang_lyrics_lis = _batch_asr(model, paths)
            langs = [_lang_decide(lang_lyrics) for lang_lyrics in lang_lyrics_lis]
            for ele, (seg_langs, seg_lyrics), lang in zip(batch, lang_lyrics_lis, langs):
                ele['lang'] = lang
                if save_middle:
                    if 'save' not in ele:
                        ele['save'] = {}
                    ele['save']['langs'] = seg_langs
                    ele['save']['lyrics'] = seg_lyrics
                new_dataset.append(ele)
                json.dump(ele, file, ensure_ascii=False)
                file.write("\n")
    return new_dataset
|
data_pipeline/meta_process/meta_phonemes.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
import copy
|
| 5 |
+
import jieba
|
| 6 |
+
import string
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from g2p_en import G2p
|
| 9 |
+
from my_tool import BASE_DIR, load_json
|
| 10 |
+
from pypinyin import pinyin, Style, load_phrases_dict
|
| 11 |
+
from pypinyin_dict.phrase_pinyin_data import cc_cedict
|
| 12 |
+
|
| 13 |
+
# Load CC-CEDICT phrase pinyin data so multi-character words get correct
# readings before any conversion happens.
cc_cedict.load()
# Syllabic nasals ("n", "ng", "m") that have no vowel and need special handling.
re_special_pinyin = re.compile(r'^(n|ng|m)$')
# Project-maintained polyphone corrections; overrides pypinyin's defaults.
reference = load_json("poly_correct.json")
load_phrases_dict(reference)
|
| 17 |
+
|
| 18 |
+
# ===== Chinese Conversion =====
|
| 19 |
+
|
| 20 |
+
def _split_py(py):
|
| 21 |
+
"""Split pinyin with tone number into initial (sm) and final (ym) parts"""
|
| 22 |
+
tone = py[-1]
|
| 23 |
+
py = py[:-1]
|
| 24 |
+
sm = ""
|
| 25 |
+
ym = ""
|
| 26 |
+
suf_r = ""
|
| 27 |
+
if re_special_pinyin.match(py):
|
| 28 |
+
py = 'e' + py
|
| 29 |
+
if py[-1] == 'r':
|
| 30 |
+
suf_r = 'r'
|
| 31 |
+
py = py[:-1]
|
| 32 |
+
|
| 33 |
+
if len(py) == 0:
|
| 34 |
+
# rx
|
| 35 |
+
return "", suf_r + tone
|
| 36 |
+
|
| 37 |
+
if py == 'zi' or py == 'ci' or py == 'si' or py == 'ri':
|
| 38 |
+
sm = py[:1]
|
| 39 |
+
ym = "ii"
|
| 40 |
+
elif py == 'zhi' or py == 'chi' or py == 'shi':
|
| 41 |
+
sm = py[:2]
|
| 42 |
+
ym = "iii"
|
| 43 |
+
elif py == 'ya' or py == 'yan' or py == 'yang' or py == 'yao' or py == 'ye' or py == 'yong' or py == 'you':
|
| 44 |
+
sm = ""
|
| 45 |
+
ym = 'i' + py[1:]
|
| 46 |
+
elif py == 'yi' or py == 'yin' or py == 'ying':
|
| 47 |
+
sm = ""
|
| 48 |
+
ym = py[1:]
|
| 49 |
+
elif py == 'yu' or py == 'yv' or py == 'yuan' or py == 'yvan' or py == 'yue ' or py == 'yve' or py == 'yun' or py == 'yvn':
|
| 50 |
+
sm = ""
|
| 51 |
+
ym = 'v' + py[2:]
|
| 52 |
+
elif py == 'wu':
|
| 53 |
+
sm = ""
|
| 54 |
+
ym = "u"
|
| 55 |
+
elif py[0] == 'w':
|
| 56 |
+
sm = ""
|
| 57 |
+
ym = "u" + py[1:]
|
| 58 |
+
elif len(py) >= 2 and (py[0] == 'j' or py[0] == 'q' or py[0] == 'x') and py[1] == 'u':
|
| 59 |
+
sm = py[0]
|
| 60 |
+
ym = 'v' + py[2:]
|
| 61 |
+
else:
|
| 62 |
+
seg_pos = re.search('a|e|i|o|u|v', py)
|
| 63 |
+
try:
|
| 64 |
+
sm = py[:seg_pos.start()]
|
| 65 |
+
ym = py[seg_pos.start():]
|
| 66 |
+
if ym == 'ui':
|
| 67 |
+
ym = 'uei'
|
| 68 |
+
elif ym == 'iu':
|
| 69 |
+
ym = 'iou'
|
| 70 |
+
elif ym == 'un':
|
| 71 |
+
ym = 'uen'
|
| 72 |
+
elif ym == 'ue':
|
| 73 |
+
ym = 've'
|
| 74 |
+
except Exception:
|
| 75 |
+
sm = ym = ""
|
| 76 |
+
return sm, ym
|
| 77 |
+
ym += suf_r + tone
|
| 78 |
+
return sm, ym
|
| 79 |
+
|
| 80 |
+
# All Chinese punctuation
# Regex character class covering common CJK punctuation: 。，？！；：,
# curly quotes, 《》〈〉【】『』, em dash, ellipsis, 、 and full-width parentheses.
chinese_punctuation_pattern = r'[\u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09]'
|
| 82 |
+
|
| 83 |
+
def _has_ch_punc(text):
    """True if `text` contains any Chinese punctuation character."""
    return re.search(chinese_punctuation_pattern, text) is not None
|
| 86 |
+
|
| 87 |
+
def _has_en_punc(text):
|
| 88 |
+
return text in string.punctuation
|
| 89 |
+
|
| 90 |
+
def _trans_cn(text:str, with_sp=True):
    """Convert a Chinese text segment to initial/final phonemes.

    Segments with jieba, converts each word to tone-numbered pinyin
    (Style.TONE3, neutral tone rendered as 5), splits every syllable into
    initial + final via _split_py, and appends an "sp" pause marker per word
    when `with_sp` is True. Words containing any punctuation are skipped.
    """
    phonemes = []
    # Word segmentation
    seg_list = jieba.cut(text)
    # Process word by word
    for seg in seg_list:
        # Skip whitespace-only tokens
        if seg.strip() == "": continue
        # Convert to pinyin with numeric tone suffixes (neutral tone -> 5)
        py =[_py[0] for _py in pinyin(seg, style=Style.TONE3, neutral_tone_with_five=True)]
        # Punctuation detection (skip the whole word if any syllable is punctuation)
        if any([_has_ch_punc(_py) for _py in py]) or any([_has_en_punc(_py) for _py in py]):
            continue
        # Split each syllable into initial (sm) and final (ym)
        for _py in py:
            sm, ym = _split_py(_py)
            if sm != "":
                phonemes.append(sm)
            if ym != "":
                phonemes.append(ym)
        if with_sp:
            phonemes += ["sp"]
    return phonemes
|
| 116 |
+
|
| 117 |
+
# ===== English Conversion =====
|
| 118 |
+
|
| 119 |
+
def _read_lexicon(lex_path):
|
| 120 |
+
"""Read English lexicon"""
|
| 121 |
+
lexicon = {}
|
| 122 |
+
with open(lex_path) as f:
|
| 123 |
+
for line in f:
|
| 124 |
+
temp = re.split(r"\s+", line.strip("\n"))
|
| 125 |
+
word = temp[0]
|
| 126 |
+
phones = temp[1:]
|
| 127 |
+
if word.lower() not in lexicon:
|
| 128 |
+
lexicon[word.lower()] = phones
|
| 129 |
+
return lexicon
|
| 130 |
+
|
| 131 |
+
# Path to the pronunciation lexicon shipped with the project.
LEX_PATH = BASE_DIR / f"data/ref/lexion.txt"
# word (lowercase) -> list of phonemes; extended at runtime by _trans_en.
lexicon = _read_lexicon(LEX_PATH)

# Grapheme-to-phoneme fallback for words missing from the lexicon.
g2p = G2p()
|
| 135 |
+
|
| 136 |
+
def _trans_en(word:str, with_sp=True):
    """Convert an English word to phonemes.

    Looks the lowercased word up in the module-level `lexicon` first and
    falls back to g2p for unknown words, caching the result. Appends an
    "sp" pause marker when `with_sp` is True and the word produced any
    phonemes.
    """
    w_lower = word.lower()
    if w_lower in lexicon:
        # Work on a copy so the "sp" appended below never reaches the lexicon.
        phonemes = list(lexicon[w_lower])
    else:
        phonemes = list(g2p(w_lower) or [])
        # BUGFIX: cache a private copy. Previously the SAME list object was
        # cached and then mutated by the append("sp") below, so every repeat
        # lookup of a g2p word accumulated extra "sp" markers.
        lexicon[w_lower] = list(phonemes)
    if len(phonemes) > 0 and with_sp:
        phonemes.append("sp")
    return phonemes
|
| 153 |
+
|
| 154 |
+
# ===== Single Sentence Processing =====
|
| 155 |
+
|
| 156 |
+
def _char_lang(c:str) -> int:
|
| 157 |
+
"""
|
| 158 |
+
Check if a character is Chinese, English, or other
|
| 159 |
+
0 - Chinese
|
| 160 |
+
1 - English
|
| 161 |
+
2 - Number
|
| 162 |
+
3 - Other
|
| 163 |
+
"""
|
| 164 |
+
if '\u4e00' <= c <= '\u9fff':
|
| 165 |
+
return 0
|
| 166 |
+
elif ('a' <= c <= 'z') or ('A' <= c <= 'Z'):
|
| 167 |
+
return 1
|
| 168 |
+
elif c.isdigit():
|
| 169 |
+
return 2
|
| 170 |
+
else:
|
| 171 |
+
return 3
|
| 172 |
+
|
| 173 |
+
# Digit -> English word; used by _lang_seperate when a digit appears inside
# a predominantly English lyric.
NUMBER_MAP = {
    "0": "zero",
    "1": "one",
    "2": "two",
    "3": "three",
    "4": "four",
    "5": "five",
    "6": "six",
    "7": "seven",
    "8": "eight",
    "9": "nine",
}
|
| 185 |
+
|
| 186 |
+
def _lang_seperate(text:str) -> list[str]:
    """Split `text` into runs of a single language.

    Returns (segments, tags): parallel lists where tags use _char_lang's
    codes (0 Chinese, 1 English). Digits and other characters break runs
    and are dropped, except that the first digit of a run is spelled out as
    its English word once at least four English runs have been seen.
    Note: despite the annotation, this returns a tuple of two lists.
    """
    segments, tags = [], []
    run = ""          # characters of the current same-language run
    prev_kind = -1    # language code of the previous character
    english_runs = 0  # how many English runs have been flushed so far
    for ch in text:
        kind = _char_lang(ch)
        if kind != prev_kind:
            # Language changed: flush the accumulated run, if any.
            if run:
                segments.append(run)
                tags.append(prev_kind)
                if prev_kind == 1:
                    english_runs += 1
                run = ""
            # A digit that OPENS a new run inside mostly-English text becomes
            # its spelled-out word (only the first digit of a digit run).
            if kind == 2 and english_runs >= 4:
                segments.append(NUMBER_MAP[ch])
                tags.append(1)
            prev_kind = kind
        if kind < 2:
            run += ch
    if run:
        # Flush the trailing run.
        segments.append(run)
        tags.append(prev_kind)
    return segments, tags
|
| 215 |
+
|
| 216 |
+
def _phoneme_trans(text:str, with_sp=True):
    """Convert one lyric line to phonemes, handling mixed Chinese/English."""
    # Split by language
    lang_segs, lang_tags = _lang_seperate(text)
    # Convert segment by segment
    phonemes = []
    for lang_seg, lang_tag in zip(lang_segs, lang_tags):
        if lang_tag == 0:
            # Chinese
            phonemes += _trans_cn(lang_seg, with_sp)
        else:
            # English (tag 1 — digits were already mapped to English words
            # by _lang_seperate, other characters were dropped)
            phonemes += _trans_en(lang_seg, with_sp)
    return phonemes
|
| 230 |
+
|
| 231 |
+
# ===== Dynamic Adaptation =====
|
| 232 |
+
|
| 233 |
+
def _get_lyrics(raw_content:str) -> list[str]:
|
| 234 |
+
"""Extract lyric content from dialogue, format like '[stage][dsec:xxx][lyrics:xxx\nxxx]'"""
|
| 235 |
+
START_FORMAT = "[lyrics:"
|
| 236 |
+
start = raw_content.find(START_FORMAT)
|
| 237 |
+
if start == -1:
|
| 238 |
+
return None, None
|
| 239 |
+
content = raw_content[start+len(START_FORMAT):-1]
|
| 240 |
+
# Filter brackets
|
| 241 |
+
content = re.sub(r'\[.*?\]', '', content) # Complete brackets
|
| 242 |
+
content = re.sub(r'[\[\]]', '', content) # Unclosed brackets
|
| 243 |
+
# Split sentences
|
| 244 |
+
sentences = content.split("\n")
|
| 245 |
+
# Reconstruct
|
| 246 |
+
new_content = raw_content[:start] + START_FORMAT + content + "]"
|
| 247 |
+
return sentences, new_content
|
| 248 |
+
|
| 249 |
+
def _trans_sentences(sentences:list[str], with_sp:bool=True) -> str:
    """Convert a list of lyric lines into a wrapped phoneme envelope.

    Each sentence becomes a space-joined phoneme line; lines are joined with
    newlines inside a "[phoneme:...]" tag. All digits (pinyin tone numbers)
    are stripped from the final string.
    """
    lines = [" ".join(_phoneme_trans(sentence, with_sp)) for sentence in sentences]
    wrapped = f"[phoneme:" + "\n".join(lines) + "]"
    # Remove tone digits
    return re.sub(r'\d+', '', wrapped)
|
| 260 |
+
|
| 261 |
+
# ===== External Interface =====
|
| 262 |
+
|
| 263 |
+
def get_phonemes_meta(dataset:list[dict], save_path:str, with_sp:bool=True):
    """Append a [phoneme:...] envelope to every lyric-bearing message.

    Each entry is deep-copied; its messages after the first one are
    rewritten in place (assistant turns are left untouched) and the result
    is streamed to `save_path` as JSONL. Returns the processed entries.
    """
    processed = []
    with open(save_path, 'w', encoding='utf-8') as out:
        for entry in tqdm(dataset, desc="Phoneme trans"):
            entry = copy.deepcopy(entry)
            # Skip the first message; only non-assistant turns carry lyrics.
            for msg in entry['messages'][1:]:
                if msg['role'] == "assistant":
                    continue
                sentences, rebuilt = _get_lyrics(msg['content'])
                if sentences is None:
                    continue
                msg['content'] = rebuilt + _trans_sentences(sentences, with_sp)
            processed.append(entry)
            json.dump(entry, out, ensure_ascii=False)
            out.write("\n")
    return processed
|
data_pipeline/meta_process/meta_tags.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from transformers import (
|
| 6 |
+
Qwen3OmniMoeProcessor,
|
| 7 |
+
Qwen3OmniMoeForConditionalGeneration
|
| 8 |
+
)
|
| 9 |
+
from qwen_omni_utils import process_mm_info
|
| 10 |
+
from my_tool import get_free_gpu, audio_cut, extract_json, dup_remove, BASE_DIR
|
| 11 |
+
|
| 12 |
+
# ===== Tag Model and Processor (External) =====
|
| 13 |
+
|
| 14 |
+
def load_tag_model():
    """Load the Qwen3-Omni tag model and its processor.

    Places the model on the freest GPU in bfloat16 and eval mode, and
    disables the talker head (only text output is needed for tagging).
    Returns (model, processor).
    """
    device = f"cuda:{get_free_gpu()}"
    print(f"Using {device}")
    model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
    model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
        model_name,
        dtype=torch.bfloat16,
        # local_files_only=True
    ).to(device)
    # Speech-synthesis head is not needed for text-only tagging.
    model.disable_talker()
    model.eval()

    processor = Qwen3OmniMoeProcessor.from_pretrained(
        model_name,
        # local_files_only=True
    )
    return model, processor
|
| 32 |
+
|
| 33 |
+
# ===== Tag Annotation =====
|
| 34 |
+
|
| 35 |
+
def _format_messages(prompt:str, path:str) -> list[dict]:
|
| 36 |
+
"""Construct messages to pass to omni"""
|
| 37 |
+
messages = [
|
| 38 |
+
{
|
| 39 |
+
"role": "system",
|
| 40 |
+
"content": [
|
| 41 |
+
{"type": "text", "text": prompt}
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"role": "user",
|
| 46 |
+
"content": [
|
| 47 |
+
{"type": "audio", "audio": path},
|
| 48 |
+
]
|
| 49 |
+
}
|
| 50 |
+
]
|
| 51 |
+
return messages
|
| 52 |
+
|
| 53 |
+
def _batch_tagging(model, processor, paths:list[str], prompt:str, mode="random"):
    """Annotate a batch of songs with the omni model.

    Each song is first cut to a segment (`mode` is forwarded to audio_cut;
    its semantics live there), all segments are tagged in one batched
    generate() call, and the temporary segment files are deleted afterwards.
    Returns one decoded generation string per input path.
    """
    convs = []
    middle_paths = []  # temporary segment files, removed at the end
    output_dir = BASE_DIR / "data/temp"
    for path in paths:
        # Cut a segment of the song instead of tagging the full audio.
        seg_path = audio_cut(path, mode, output_dir)
        middle_paths.append(seg_path)
        messages = _format_messages(prompt, seg_path)
        convs.append(messages)

    USE_AUDIO_IN_VIDEO = False

    with torch.no_grad():
        text = processor.apply_chat_template(convs, add_generation_prompt=True, tokenize=False)
        audios, images, videos = process_mm_info(convs, use_audio_in_video=USE_AUDIO_IN_VIDEO)
        inputs = processor(
            text=text,
            audio=audios,
            padding=True,
            images=images,
            videos=videos,
            return_tensors="pt",
            use_audio_in_video=USE_AUDIO_IN_VIDEO
        )
        inputs = inputs.to(model.device).to(model.dtype)

        text_ids = model.generate(
            **inputs,
            max_new_tokens=2048,
            return_audio=False,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=USE_AUDIO_IN_VIDEO
        )
        # Decode only the newly generated tokens (strip the prompt prefix).
        gene_texts = processor.batch_decode(
            text_ids[0].sequences[:, inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

    torch.cuda.empty_cache()
    # Delete the temporary audio segments
    for path in middle_paths:
        if os.path.exists(path):
            os.remove(path)
    return gene_texts
|
| 99 |
+
|
| 100 |
+
# ===== External Interface =====
|
| 101 |
+
|
| 102 |
+
def get_tags_meta(model, processor, dataset:list[dict], prompt:str, bs:int, save_path:str):
    """Annotate every audio in `dataset` with model-generated tags.

    - Each entry needs a 'path' field; results are written to its 'tags' field.
    - Entries already present in save_path are skipped via dup_remove.
    - Output is appended to save_path as JSONL in real time.

    Returns the entries tagged in this run.
    """
    dataset = dup_remove(dataset, save_path, 'path', 'tags')
    # BUGFIX: the length must be measured AFTER dup_remove — the previous
    # code measured the unfiltered list, so the batch loop ran past the end
    # of the filtered dataset on resumed runs.
    data_num = len(dataset)
    new_dataset = []
    with open(save_path, 'a', encoding="utf-8") as file:
        for i in tqdm(range(0, data_num, bs)):
            batch = []
            paths = []
            for ele in dataset[i:i+bs]:
                path = ele['path']
                if os.path.exists(path):
                    batch.append(ele)
                    paths.append(path)
            if not paths:
                # Every file in this batch is missing on disk; nothing to tag.
                continue
            contents = _batch_tagging(model, processor, paths, prompt)
            for ele, content in zip(batch, contents):
                ok, json_data = extract_json(content)
                if not ok:
                    # Model output was not parseable JSON; skip this entry.
                    continue
                ele['tags'] = json_data['tags']
                new_dataset.append(ele)
                json.dump(ele, file, ensure_ascii=False)
                file.write('\n')
    return new_dataset
|
data_pipeline/meta_process/meta_vocal.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import librosa
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from demucs.pretrained import get_model
|
| 10 |
+
from demucs.apply import apply_model
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ======================
# Basic Configuration
# ======================

SAMPLE_RATE = 44100
MAX_DURATION = 5  # Seconds of audio analyzed per file (see load_audio_30s)
BATCH_SIZE = 8  # Can adjust to 16~32 for 80G GPU memory
VOCAL_DB_THRESHOLD = -35.0  # Vocal presence threshold (empirical value)
# BUGFIX: pick the device automatically. The hard-coded "cuda:1" crashed on
# machines with fewer than two GPUs (the portable line below had been
# commented out in favor of a debugging value).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ======================
|
| 26 |
+
# Audio Loading (first 30s only)
|
| 27 |
+
# ======================
|
| 28 |
+
|
| 29 |
+
def load_audio_30s(path: str, sr: int = SAMPLE_RATE) -> torch.Tensor:
    """
    Load the first MAX_DURATION seconds of an audio file.

    NOTE(review): the name says 30s but MAX_DURATION is currently 5 — the
    constant, not the name, is authoritative.

    Returns shape: [channels=2, samples]; mono input is duplicated to stereo.
    """
    y, _ = librosa.load(
        path,
        sr=sr,
        mono=False,
        duration=MAX_DURATION
    )

    if y.ndim == 1:
        # Mono file: duplicate the channel so callers always see [2, samples].
        y = np.stack([y, y], axis=0)

    return torch.from_numpy(y)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ======================
|
| 47 |
+
# dB Calculation
|
| 48 |
+
# ======================
|
| 49 |
+
|
| 50 |
+
def rms_db(wav: torch.Tensor) -> float:
    """Return the RMS level of `wav` ([channels, samples]) in dB."""
    mean_square = (wav ** 2).mean()
    # The 1e-8 epsilon guards against log10(0) on perfectly silent input.
    level = 20 * torch.log10(mean_square.sqrt() + 1e-8)
    return level.item()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ======================
|
| 60 |
+
# Demucs Vocal Detection
|
| 61 |
+
# ======================
|
| 62 |
+
|
| 63 |
+
class DemucsVocalDetector:
    """Batch vocal-presence detector built on htdemucs source separation.

    Separates the vocal stem from each clip and decides "has vocals" by
    comparing the stem's RMS level against VOCAL_DB_THRESHOLD.
    """

    def __init__(self):
        # Pretrained Demucs model in eval mode, pinned to DEVICE.
        self.model = (
            get_model("htdemucs")
            .to(DEVICE)
            .eval()
        )
        # Index of the "vocals" stem among the model's output sources.
        self.vocal_idx = self.model.sources.index("vocals")

    @torch.no_grad()
    def batch_has_vocal(self, audio_paths: List[str]) -> Dict[str, bool]:
        """
        Input: List of audio paths
        Output: {path: whether has vocals}

        Files that fail to load are reported as False.
        NOTE(review): all loadable files are stacked into ONE batch
        regardless of BATCH_SIZE — confirm memory behavior for long lists.
        """
        results = {}

        batch_wavs = []
        batch_paths = []

        print("start load")
        for path in audio_paths:
            try:
                wav = load_audio_30s(path)
                batch_wavs.append(wav)
                batch_paths.append(path)
            except Exception as e:
                # Unreadable file: mark as "no vocals" rather than aborting.
                results[path] = False
                continue

        print("finish load")
        self._process_batch(batch_wavs, batch_paths, results)

        return results

    def _process_batch(self, wavs, paths, results):
        """Run Demucs on one padded batch and fill `results` in place."""
        # Zero-pad every clip to the longest one so they stack into a tensor.
        max_len = max(w.shape[1] for w in wavs)
        padded = []

        for w in wavs:
            if w.shape[1] < max_len:
                w = torch.nn.functional.pad(w, (0, max_len - w.shape[1]))
            padded.append(w)

        batch = torch.stack(padded, dim=0).to(DEVICE)

        print("start demucs")
        torch.cuda.synchronize()

        sources = apply_model(
            self.model,
            batch,
            SAMPLE_RATE,
            device=DEVICE,
            split=False,  # process each clip in a single pass (no chunking)
            progress=False
        )

        torch.cuda.synchronize()
        print("demucs done")

        # Select the separated vocal stem: [batch, channels, samples].
        vocals = sources[:, self.vocal_idx]

        for i, path in enumerate(paths):
            db = rms_db(vocals[i])
            # A stem louder than the threshold counts as "has vocals".
            results[path] = db > VOCAL_DB_THRESHOLD
|
| 129 |
+
|
| 130 |
+
# ======================
# Usage Example
# ======================

if __name__ == "__main__":
    # Fill in paths to test; an empty list exercises nothing.
    audio_list = []

    detector = DemucsVocalDetector()
    result = detector.batch_has_vocal(audio_list)

    for k, v in result.items():
        print(f"{k}: {'Has' if v else 'No'} vocals")
|
data_pipeline/meta_process/my_tool.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
import tarfile
|
| 6 |
+
import subprocess
|
| 7 |
+
import json_repair
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from pydub import AudioSegment
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
| 13 |
+
|
| 14 |
+
# ===== Macros =====
|
| 15 |
+
|
| 16 |
+
BASE_DIR = Path(__file__).parent
|
| 17 |
+
|
| 18 |
+
# ===== Helper Functions =====
|
| 19 |
+
|
| 20 |
+
def pure_name(path:str):
    """Get the original name of a file path (without extension)"""
    base = os.path.basename(path)
    # rpartition splits on the LAST dot; an empty separator means no dot.
    stem, sep, _ = base.rpartition('.')
    return stem if sep else base
|
| 27 |
+
|
| 28 |
+
def extract_json(text: str) -> tuple[bool, dict]:
    """Extract and repair JSON data from text (enhanced error-tolerant version)

    Features:
    1. Automatically identify code block markers (```json``` or ```)
    2. Fix common JSON errors (mismatched quotes, trailing commas, etc.)
    3. Support lenient parsing mode

    Returns: (success, parsed dictionary); on failure the dict carries the
    exceptions from both the strict parse and the repair attempt.
    """
    # Preprocessing: extract possible JSON content area
    content = text

    # Case 1: ```json fenced block.
    # Bug fix: the opening fence is 7 characters ("```json"), not 6, and a
    # missing closing fence used to slice with end == -1, dropping the last
    # character of the payload.
    if '```json' in text:
        start = text.find('```json') + len('```json')
        end = text.find('```', start)
        content = text[start:end if end != -1 else len(text)].strip()
    # Case 2: plain ``` fenced block
    elif '```' in text:
        start = text.find('```') + 3
        end = text.find('```', start)
        content = text[start:end if end != -1 else len(text)].strip()

    # Clean common interference items in content
    content = re.sub(r'^[^{[]*', '', content)   # drop everything before the first { or [
    content = re.sub(r'[^}\]]*$', '', content)  # drop everything after the last } or ]

    # Try standard parsing first
    try:
        return True, json.loads(content)
    except json.JSONDecodeError as e:
        standard_error = e

    # Fall back to json_repair's lenient fixer
    try:
        repaired = json_repair.repair_json(content)
        return True, json.loads(repaired)
    except Exception as repair_error:
        return False, {
            "standard_error": standard_error,
            "repair_error": repair_error,
        }
|
| 74 |
+
|
| 75 |
+
def path_join(dir, name):
    """Join a directory and a file name into a single path."""
    joined = os.path.join(dir, name)
    return joined
|
| 77 |
+
|
| 78 |
+
def dict_sort_print(dic:dict, value:bool=True, reverse=True):
    """Sort a dictionary by value size and output"""
    # Sort on the value (index 1) or the key (index 0) of each item.
    position = 1 if value else 0
    ordered = dict(sorted(dic.items(), key=lambda item: item[position], reverse=reverse))
    print(json.dumps(ordered, indent=4, ensure_ascii=False))
|
| 86 |
+
|
| 87 |
+
def clean_newlines(text: str) -> str:
    """
    Clean lyric line breaks:
    1. Keep line breaks after punctuation
    2. Convert line breaks after non-punctuation → space
    3. Fix extra spaces after English apostrophes
    4. Merge redundant spaces
    5. Preserve paragraph structure, ensure line breaks after punctuation

    NOTE: the steps below are order-sensitive — the text is first flattened
    to one line, then re-broken strictly after punctuation.
    """
    if not text:
        return ""

    text = text.strip()

    # First unify line breaks to \n
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Merge non-empty lines into one sentence (remove original line breaks first)
    lines = [line.strip() for line in text.split('\n')]
    text = ' '.join(line for line in lines if line)

    # Add line break after sentence-ending punctuation (Chinese and English punctuation)
    text = re.sub(r'([.,!?:;,。!?;])\s*', r'\1\n', text)

    # Fix spaces after English apostrophes (e.g. "don' t" -> "don't")
    text = re.sub(r"'\s+", "'", text)

    # Merge redundant spaces (tabs included); newlines are preserved
    text = re.sub(r'[ \t]+', ' ', text)

    # Remove leading and trailing spaces from lines
    text = '\n'.join(line.strip() for line in text.split('\n'))

    return text.strip()
|
| 121 |
+
|
| 122 |
+
# ===== Detection Functions =====
|
| 123 |
+
def is_ch_char(char:str):
    """Determine if a single character is a Chinese character"""
    if len(char) != 1:
        return False

    # Unicode ranges checked:
    #   CJK Unified Ideographs        0x4E00-0x9FFF (covers most cases)
    #   CJK Extension A               0x3400-0x4DBF
    # Extensions B-E (astral planes) are intentionally not considered.
    codepoint = ord(char)
    in_basic = 0x4E00 <= codepoint <= 0x9FFF
    in_ext_a = 0x3400 <= codepoint <= 0x4DBF
    return in_basic or in_ext_a
|
| 147 |
+
|
| 148 |
+
# ===== File Operations =====
|
| 149 |
+
|
| 150 |
+
def load_txt(path:str) -> str:
    """Load a file as plain text.

    Opens with explicit UTF-8 so decoding does not depend on the platform's
    default locale encoding (the sibling save helpers already write UTF-8).
    """
    with open(path, 'r', encoding='utf-8') as file:
        content = file.read()
    return content
|
| 155 |
+
|
| 156 |
+
def load_json(path:str):
    """Load a JSON file; return an empty dict when the file does not exist.

    Opens with explicit UTF-8 for consistency with save_json.
    """
    if not os.path.exists(path):
        return {}
    with open(path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data
|
| 163 |
+
|
| 164 |
+
def load_jsonl(path:str, limit=-1) -> list[dict]:
    """Load a JSONL file.

    Args:
        path: file with one JSON object per line.
        limit: stop after this many records; -1 loads everything.

    Opens with explicit UTF-8 (fix: decoding previously depended on the
    platform default); also renamed the loop variable so it no longer
    shadows the builtin `id`.
    """
    data = []
    with open(path, 'r', encoding='utf-8') as file:
        for idx, line in tqdm(enumerate(file), desc=f"Loading {path}"):
            if limit != -1 and idx == limit:
                break
            data.append(json.loads(line))
    return data
|
| 173 |
+
|
| 174 |
+
def save_json(data, path:str):
    """Save a JSON file"""
    with open(path, 'w', encoding='utf-8') as out:
        json.dump(data, out, ensure_ascii=False, indent=4)
|
| 178 |
+
|
| 179 |
+
def save_jsonl(data:list[dict], path:str, mode='w'):
    """Save a JSONL file (one JSON object per line); `mode` allows 'a' for append."""
    with open(path, mode, encoding='utf-8') as out:
        for record in tqdm(data, desc=f"Saving to {path}"):
            out.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 185 |
+
|
| 186 |
+
def audio_cut(input_path, mode:str, output_dir:str, segment_length:int=30000):
    """
    Extract a segment of specified length from an audio file
    - mode: Cut type (random / middle)
    - output_dir: Output folder
    - segment_length: Segment length (milliseconds)

    Returns the path of the exported WAV segment (mono, 44.1 kHz, 16-bit PCM).
    """
    assert mode in ['random', 'middle']

    # Check if file exists
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Audio file not found: {input_path}")

    # Load audio file and normalize format
    audio = AudioSegment.from_file(input_path)
    audio = audio.set_frame_rate(44100).set_channels(1)  # Set sample rate and channels
    audio_duration = len(audio)  # pydub length is in milliseconds

    # If audio length is less than target segment length, use entire audio
    if audio_duration <= segment_length:
        print(f"Warning: Audio too short ({audio_duration}ms), using full audio: {input_path}")
        segment = audio
    else:
        # Calculate slice position based on mode
        if mode == "random":
            # Random cut: any start that still leaves a full segment
            max_start = max(0, audio_duration - segment_length)
            start = random.randint(0, max_start)
            end = start + segment_length
        else:
            # Cut from middle, centered on the midpoint
            middle_point = audio_duration // 2
            start = max(0, middle_point - (segment_length // 2))
            end = min(audio_duration, start + segment_length)

        # If the window would exceed boundaries, shift it back inside.
        # NOTE(review): given the computations above, end <= audio_duration and
        # start >= 0 always hold here, so these branches look unreachable —
        # kept as defensive guards.
        if end > audio_duration:
            end = audio_duration
            start = end - segment_length
        elif start < 0:
            start = 0
            end = segment_length

        # Ensure slice range is valid
        start = max(0, min(start, audio_duration))
        end = max(0, min(end, audio_duration))

        if start >= end:
            raise ValueError(f"Invalid slice range: start={start}, end={end}, duration={audio_duration}")

        # Execute slice (millisecond indexing on the AudioSegment)
        segment = audio[start:end]

    # Generate output path: seg_<original stem>.wav in output_dir
    basename = pure_name(input_path)
    output_path = os.path.join(output_dir, f"seg_{basename}.wav")

    # Save segment
    segment.export(
        output_path,
        format="wav",
        codec="pcm_s16le",  # 16-bit little-endian encoding
        parameters=["-acodec", "pcm_s16le"]  # ffmpeg parameters
    )
    return output_path
|
| 251 |
+
|
| 252 |
+
def format_meta(dir:str, show:bool=True) -> list[dict]:
    """Recursively get all audio paths (wav / mp3) in a folder and build JSONL.

    Args:
        dir: root directory; non-directories yield [].
        show: display a progress bar at this level (recursive calls pass False).

    Returns a list of {"path": <absolute-or-joined path>} records.

    Refactor: the show/no-show branches duplicated the whole walk; they now
    share one loop and only differ in the iterator wrapper.
    """
    if not os.path.isdir(dir):
        return []
    dataset = []
    names = os.listdir(dir)
    iterator = tqdm(names, desc=f"Formating {dir}") if show else names
    for name in iterator:
        path = os.path.join(dir, name)
        if os.path.isdir(path):
            # Recurse without a progress bar to avoid nested bars.
            dataset += format_meta(path, False)
        elif name.endswith(('.mp3', '.wav')):
            dataset.append({"path": path})
    return dataset
|
| 272 |
+
|
| 273 |
+
def dup_remove(raw_data:list[dict], save_path:str, key:str, seg:str):
    """
    Remove already generated items from dataset
    - key is the primary key in raw dataset, foreign key in save
    - seg is the target field
    """
    if not os.path.exists(save_path):
        print(f"Dup num: 0")
        return raw_data

    # Keys already present in the save file (only records that carry `seg`).
    saved = load_jsonl(save_path)
    seen = {ele[key] for ele in tqdm(saved, desc="Constructing Dup Set") if seg in ele}

    remaining = []
    dup_count = 0
    for ele in tqdm(raw_data, desc="Checking Dup"):
        if ele[key] in seen:
            dup_count += 1
        else:
            remaining.append(ele)
    print(f"Dup num: {dup_count}")
    return remaining
|
| 296 |
+
|
| 297 |
+
def tar_size_check(data_dir:str, subfixes:list[str], per:int, max_size:int):
    """
    Determine the number of files that can fit in a block before compression (assuming uniform file sizes)
    - data_dir: Folder to compress
    - subfixes: File suffixes to compress (e.g., .mp3)
    - per: Check every N files on average
    - max_size: Maximum limit in GB

    Bug fix: `gb_size` was only assigned inside the every-`per`-files
    checkpoint, so the final print raised NameError when fewer than `per`
    files matched, and otherwise reported a stale checkpoint value.
    """
    names = sorted(list(os.listdir(data_dir)))
    count = 0
    size_sum = 0
    for name in tqdm(names, desc="Size Checking"):
        path = os.path.join(data_dir, name)
        subfix = os.path.splitext(name)[1]
        if subfix not in subfixes:
            continue
        count += 1
        size_sum += os.path.getsize(path)
        if count % per == 0:
            gb_size = size_sum / 1024 / 1024 / 1024
            if gb_size > max_size:
                break
    # Recompute from the running total so the report is always defined and accurate.
    gb_size = size_sum / 1024 / 1024 / 1024
    print(f"Count: {count}, Size: {gb_size:.2f}GB")
|
| 320 |
+
|
| 321 |
+
def tar_dir(
    data_dir:str,
    subfixes:list[str],
    save_dir:str,
    group_size:int,
    tmp_dir:str,
    mark:str,
    max_workers:int=10,
):
    """Compress files in a directory in chunks (non-recursive).

    Splits the (sorted) directory listing into groups of `group_size`,
    writes each group's relative file names to a list file in `tmp_dir`,
    then pipes `tar --files-from ... -cf -` into `pigz -p max_workers`
    and streams the result to save_dir/block_<i>_<mark>.tar.gz.
    """
    names = sorted(list(os.listdir(data_dir)))
    file_num = len(names)
    for i in range(0, file_num, group_size):
        names_subset = names[i:i+group_size]
        size_sum = 0
        # File listing consumed by tar's --files-from; paths are relative
        # to data_dir (tar runs with cwd=data_dir below).
        name_path = os.path.join(tmp_dir, f"name_{i}_{mark}")
        with open(name_path, 'w', encoding='utf-8') as file:
            for name in tqdm(names_subset, desc=f"Counting Block {i}"):
                path = os.path.join(data_dir, name)
                subfix = os.path.splitext(path)[1]
                if subfix not in subfixes:
                    continue
                file.write("./" + name + "\n")
                size_sum += os.path.getsize(path)
        gb_size = size_sum / 1024 / 1024 / 1024
        print(f"Zipping block {i+1}, size: {gb_size:.2f}GB")

        tar_cmd = [
            'tar',
            '--no-recursion',
            '--files-from', str(name_path),
            '-cf', '-'
        ]
        pigz_cmd = ['pigz', '-p', str(max_workers), '-c']

        # tar -> pigz pipeline; both run in data_dir so the listed relative
        # paths resolve.
        # NOTE(review): tar_process.stdout is not closed in the parent after
        # handing it to pigz — usually harmless here, but worth confirming if
        # pigz ever exits early.
        tar_process = subprocess.Popen(tar_cmd, stdout=subprocess.PIPE, cwd=data_dir)
        pigz_process = subprocess.Popen(pigz_cmd, stdin=tar_process.stdout, stdout=subprocess.PIPE, cwd=data_dir)

        # Stream compressed output to the block archive in 4 KiB chunks.
        save_path = os.path.join(save_dir, f"block_{i}_{mark}.tar.gz")
        with open(save_path, 'wb') as out_file:
            while True:
                data = pigz_process.stdout.read(4096)
                if not data:
                    break
                out_file.write(data)

        tar_process.wait()
        pigz_process.wait()

        if tar_process.returncode == 0 and pigz_process.returncode == 0:
            print(f"Compression completed: {save_path}")
        else:
            print(f"Compression failed: tar return code={tar_process.returncode}, pigz return code={pigz_process.returncode}")
|
| 374 |
+
|
| 375 |
+
def music_avg_size(dir:str):
    """Average music size (MB), length (s), sampled over at most 50 files.

    Returns (avg_size_mb, avg_length_s). Bug fix: previously raised
    ZeroDivisionError when the directory contained no audio files; now
    returns (0.0, 0.0) in that case.
    """
    dataset = format_meta(dir)
    dataset = dataset[:50]  # sample cap to keep the scan cheap
    if not dataset:
        return 0.0, 0.0
    size_sum = 0
    length_sum = 0
    for ele in tqdm(dataset, desc=f"Counting Music Size in {dir}"):
        path = ele['path']
        audio = AudioSegment.from_file(path)
        length_sum += len(audio) / 1000.0  # pydub length is ms
        size_sum += os.path.getsize(path)
    size_avg = size_sum / len(dataset) / 1024 / 1024
    length_avg = length_sum / len(dataset)
    return size_avg, length_avg
|
| 389 |
+
|
| 390 |
+
def get_sample(path:str, save_path:str="tmp.jsonl", num:int=100):
    """Get up to N random records from a JSONL/JSON file and save them as JSONL.

    Bug fix: random.sample raises ValueError when num exceeds the dataset
    size; the sample size is now clamped so small files yield the whole file.
    """
    if not os.path.exists(path):
        return
    if path.endswith(".jsonl"):
        dataset = load_jsonl(path)
    elif path.endswith(".json"):
        dataset = load_json(path)
    else:
        print(f"Unsupport file: {path}")
        return
    # NOTE(review): load_json returns a dict for missing/object files;
    # random.sample over a dict raises TypeError — confirm .json inputs
    # are actually JSON arrays.
    sub_dataset = random.sample(dataset, min(num, len(dataset)))
    save_jsonl(sub_dataset, save_path)
|
| 403 |
+
|
| 404 |
+
def _get_field_one(path:str, field:str):
|
| 405 |
+
"""Process data from one path"""
|
| 406 |
+
with open(path, 'r') as file:
|
| 407 |
+
data = json.load(file)
|
| 408 |
+
new_data = {
|
| 409 |
+
"id": f"{data['song_id']}_{data['track_index']}",
|
| 410 |
+
field: data[field]
|
| 411 |
+
}
|
| 412 |
+
return new_data
|
| 413 |
+
|
| 414 |
+
def get_field_suno(dir:str, save_path:str, field:str, max_workers:int=8):
    """Extract a specific field from scattered JSON files in suno"""
    # Collect every .json file in the directory (non-recursive).
    paths = [
        os.path.join(dir, name)
        for name in tqdm(os.listdir(dir), desc="Getting names")
        if name.endswith(".json")
    ]

    # Fan out parsing across processes; write results as they complete
    # (output order is therefore completion order, as before).
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_get_field_one, p, field) for p in paths]
        with open(save_path, 'w', encoding='utf-8') as out:
            with tqdm(total=len(paths), desc="Processing the JSONs") as pbar:
                for done in as_completed(pending):
                    out.write(json.dumps(done.result(), ensure_ascii=False) + "\n")
                    pbar.update(1)
|
| 431 |
+
|
| 432 |
+
def find_json(dir:str) -> list[str]:
    """Find JSONL / JSON files in a folder"""
    return [
        name
        for name in tqdm(os.listdir(dir), desc="Finding JSON/JSONL")
        if name.endswith((".json", ".jsonl"))
    ]
|
| 439 |
+
|
| 440 |
+
def show_dir(dir:str):
    """Display all contents in a directory"""
    if not os.path.isdir(dir):
        return
    for entry in os.listdir(dir):
        print(entry)
|
| 446 |
+
|
| 447 |
+
def _convert_mp3(path:str, dir:str):
    """Process a single audio file; returns 'pass' / 'fail' / 'finish'."""
    target = os.path.join(dir, pure_name(path) + ".mp3")
    if os.path.exists(target):
        # Already converted on a previous run — skip.
        return "pass"
    try:
        audio = AudioSegment.from_file(path)
    except Exception:
        # Unreadable / corrupt source file.
        print(f"fail to load {path}")
        return "fail"
    audio.export(target, format='mp3')
    return "finish"
|
| 462 |
+
|
| 463 |
+
def convert_mp3(meta_path:str, dir:str, max_workers:int=10):
    """Convert all specified audio files to mp3 and save in specified directory"""
    os.makedirs(dir, exist_ok=True)
    dataset = load_jsonl(meta_path)
    pass_num, finish_num = 0, 0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_convert_mp3, ele['path'], dir) for ele in dataset]
        with tqdm(total=len(dataset), desc=f"Converting {meta_path}") as pbar:
            for done in as_completed(pending):
                # "pass" means already converted; anything else counts as work done.
                if done.result() == "pass":
                    pass_num += 1
                else:
                    finish_num += 1
                pbar.update(1)
    print(f"Finish {finish_num}, Pass {pass_num}")
|
| 480 |
+
|
| 481 |
+
# ===== GPU and Models =====
|
| 482 |
+
|
| 483 |
+
def get_free_gpu() -> int:
    """Return the GPU ID with the least memory usage"""
    cmd = "nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits"
    lines = subprocess.check_output(cmd.split()).decode().strip().split("\n")

    # Parse "index, free-MiB" rows and pick the one with the most free memory.
    stats = []
    for line in lines:
        idx, free_mem = line.split(",")
        stats.append((int(idx), int(free_mem)))
    best = max(stats, key=lambda pair: pair[1])
    return best[0]
|
| 496 |
+
|
| 497 |
+
# ===== Data Analysis =====
|
| 498 |
+
|
| 499 |
+
def compose_analyze(dataset:list[dict]):
    """Statistical analysis of music structure composition.

    Prints per-label frequencies and per-song label-combination frequencies.
    Bug fix: dict_sort_print prints and returns None, so wrapping it in
    print() emitted a stray "None" line after each table; it is now called
    directly. Also stopped rebinding `labels` to two different things.
    """
    # Label count across all segments
    label_count = defaultdict(int)
    for ele in tqdm(dataset):
        for segment in ele['segments']:
            label_count[segment['label']] += 1
    print(f"Number of labels: {len(label_count)}")
    dict_sort_print(label_count)

    # Frequency of each ordered label combination (one per song)
    label_combs = defaultdict(int)
    for ele in tqdm(dataset):
        seq = [segment['label'] for segment in ele['segments']]
        if len(seq) == 0:
            continue
        label_combs[" | ".join(seq)] += 1
    print(f"Number of combinations: {len(label_combs)}")
    dict_sort_print(label_combs)
|
| 525 |
+
|
| 526 |
+
def _filter_tag(content:str) -> list[str]:
|
| 527 |
+
"""Split and format tag fields"""
|
| 528 |
+
tags = []
|
| 529 |
+
raws = re.split(r'[,,.]', content)
|
| 530 |
+
for raw in raws:
|
| 531 |
+
raw = raw.strip().lower() # Remove spaces and convert to lowercase
|
| 532 |
+
if raw == "":
|
| 533 |
+
continue
|
| 534 |
+
seg_pos = raw.find(":")
|
| 535 |
+
if seg_pos != -1:
|
| 536 |
+
# If colon exists, only take the part after it
|
| 537 |
+
tag = raw[seg_pos+1:].strip()
|
| 538 |
+
else:
|
| 539 |
+
tag = raw
|
| 540 |
+
tags.append(tag)
|
| 541 |
+
return tags
|
| 542 |
+
|
| 543 |
+
def tags_analyze(dataset:list[dict]):
    """Song tag analysis: print tag frequencies over each record's 'style'.

    Bug fix: dict_sort_print prints and returns None, so wrapping it in
    print() emitted a stray "None" line; it is now called directly.
    """
    tag_count = defaultdict(int)
    for ele in tqdm(dataset, desc="Tag analyzing"):
        tags = _filter_tag(ele['style'])
        for tag in tags:
            tag_count[tag] += 1
    print(f"Number of tags: {len(tag_count.keys())}")
    dict_sort_print(tag_count)
|