Jacong committed on
Commit
aa9be1e
·
verified ·
1 Parent(s): 35380aa

Upload 96 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +11 -0
  2. LICENSE +21 -0
  3. README.md +80 -12
  4. assets/intro.jpg +3 -0
  5. baseline_generate/ace_step/convert.py +160 -0
  6. baseline_generate/ace_step/infer.py +122 -0
  7. baseline_generate/diffrhythm2/batch_inference.sh +57 -0
  8. baseline_generate/diffrhythm2/batch_inference_en.sh +57 -0
  9. baseline_generate/diffrhythm2/inference.py +294 -0
  10. baseline_generate/diffrhythm2/inference.sh +10 -0
  11. baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt.cpython-311.pyc +0 -0
  12. baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt_justtag.cpython-311.pyc +0 -0
  13. baseline_generate/diffrhythm2/scripts/proce_song.py +92 -0
  14. baseline_generate/diffrhythm2/scripts/proce_song_enprompt.py +145 -0
  15. baseline_generate/diffrhythm2/scripts/proce_song_enprompt_justtag.py +143 -0
  16. baseline_generate/levo/__pycache__/generate.cpython-311.pyc +0 -0
  17. baseline_generate/levo/convert.py +110 -0
  18. baseline_generate/levo/generate.py +591 -0
  19. baseline_generate/mureka_o2/__pycache__/generate.cpython-311.pyc +0 -0
  20. baseline_generate/mureka_o2/generate.py +390 -0
  21. baseline_generate/suno/__pycache__/suno_4_5.cpython-311.pyc +0 -0
  22. baseline_generate/suno/__pycache__/suno_5.cpython-311.pyc +0 -0
  23. baseline_generate/suno/config.py +70 -0
  24. baseline_generate/suno/suno_4_5.py +766 -0
  25. baseline_generate/suno/suno_5.py +768 -0
  26. baseline_generate/yue/__pycache__/infer_batch.cpython-311.pyc +0 -0
  27. baseline_generate/yue/batch.sh +55 -0
  28. baseline_generate/yue/codecmanipulator.py +204 -0
  29. baseline_generate/yue/infer_batch.py +904 -0
  30. baseline_generate/yue/mmtokenizer.py +367 -0
  31. data_pipeline/lyrics_gene/__pycache__/filter_all_cn.cpython-311.pyc +0 -0
  32. data_pipeline/lyrics_gene/__pycache__/filter_all_en.cpython-311.pyc +0 -0
  33. data_pipeline/lyrics_gene/__pycache__/gen_lyrics_cn.cpython-311.pyc +0 -0
  34. data_pipeline/lyrics_gene/filter_all_cn.py +272 -0
  35. data_pipeline/lyrics_gene/filter_all_en.py +256 -0
  36. data_pipeline/lyrics_gene/gen_lyrics_cn.py +568 -0
  37. data_pipeline/lyrics_gene/gen_lyrics_en.py +577 -0
  38. data_pipeline/meta_process/convert_convs.py +98 -0
  39. data_pipeline/meta_process/convert_lyrics.py +180 -0
  40. data_pipeline/meta_process/convert_messages.py +593 -0
  41. data_pipeline/meta_process/convert_segments.py +93 -0
  42. data_pipeline/meta_process/evaluate_polyphones.py +62 -0
  43. data_pipeline/meta_process/filter.py +46 -0
  44. data_pipeline/meta_process/main.py +77 -0
  45. data_pipeline/meta_process/meta_endpoints.py +118 -0
  46. data_pipeline/meta_process/meta_lang.py +125 -0
  47. data_pipeline/meta_process/meta_phonemes.py +283 -0
  48. data_pipeline/meta_process/meta_tags.py +124 -0
  49. data_pipeline/meta_process/meta_vocal.py +141 -0
  50. data_pipeline/meta_process/my_tool.py +551 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/intro.jpg filter=lfs diff=lfs merge=lfs -text
37
+ examples/muse_outputs/main_0.6b_0.mp3 filter=lfs diff=lfs merge=lfs -text
38
+ examples/muse_outputs/main_0.6b_1.mp3 filter=lfs diff=lfs merge=lfs -text
39
+ examples/muse_outputs/main_1.7b_2.mp3 filter=lfs diff=lfs merge=lfs -text
40
+ examples/muse_outputs/main_8b_3.mp3 filter=lfs diff=lfs merge=lfs -text
41
+ examples/train_inputs/suno_cn_0.mp3 filter=lfs diff=lfs merge=lfs -text
42
+ examples/train_inputs/suno_cn_1.mp3 filter=lfs diff=lfs merge=lfs -text
43
+ examples/train_inputs/suno_en_2.mp3 filter=lfs diff=lfs merge=lfs -text
44
+ examples/train_inputs/suno_en_3.mp3 filter=lfs diff=lfs merge=lfs -text
45
+ train/train_demo.jsonl filter=lfs diff=lfs merge=lfs -text
46
+ train/val.jsonl filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 yuhui1038
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,80 @@
1
- ---
2
- title: Muse
3
- emoji: 🐢
4
- colorFrom: red
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 6.3.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Muse: Towards Reproducible Long-Form Song Generation with Fine-Grained Style Control
2
+
3
+ <p align="center">
4
+ 📄 <a href="https://arxiv.org/abs/2601.03973">Paper</a> • 📊 <a href="https://huggingface.co/datasets/bolshyC/Muse">Dataset</a> • 🤖 <a href="https://huggingface.co/bolshyC/models">Model</a> • 📚 <a href="#citation">Citation</a>
5
+ </p>
6
+
7
+ This repository is the official repository for "Muse: Towards Reproducible Long-Form Song Generation with Fine-Grained Style Control". In this repository, we provide the Muse model, training and inference scripts, pretrained checkpoints, and evaluation pipelines.
8
+
9
+ ## News and Updates
10
+
11
+ * **2026.01.11 🔥**: We are excited to announce that all datasets and models are now fully open-sourced! 🎶 The complete training dataset (116k songs), pretrained model weights, training and evaluation code, and data pipeline are publicly available.
12
+
13
+ ## Installation
14
+
15
+ **Requirements**: Python 3.10 is required.
16
+
17
+ To set up the environment for Muse:
18
+
19
+ - **For training**: Install the training framework:
20
+ ```bash
21
+ pip install ms-swift -U
22
+ ```
23
+ - **For inference**: Install vLLM:
24
+ ```bash
25
+ pip install vllm
26
+ ```
27
+ - **For audio encoding/decoding**: Some dependencies (e.g., `av`) require system-level packages. On Ubuntu/Debian, install FFmpeg 4.4+ first:
28
+ ```bash
29
+ sudo apt-get update
30
+ sudo apt-get install -y software-properties-common
31
+ sudo add-apt-repository ppa:savoury1/ffmpeg4 -y
32
+ sudo apt-get update
33
+ sudo apt-get install -y pkg-config ffmpeg libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
34
+ ```
35
+ We recommend creating a new conda environment with Python 3.10. **Note**: Since `omegaconf==2.0.6` is required and has compatibility issues with pip 24.1+, you need to downgrade pip first:
36
+ ```bash
37
+ pip install "pip<24.1"
38
+ ```
39
+ Then install dependencies:
40
+ ```bash
41
+ pip install --default-timeout=1000 -r requirements_mucodec.txt
42
+ ```
43
+ For more details, please refer to the [MuCodec](https://github.com/tencent-ailab/MuCodec) official repository.
44
+
45
+ - **For data pipeline and evaluation**: If you need to run data processing scripts (lyrics generation, metadata processing) or evaluation scripts, install additional dependencies:
46
+ ```bash
47
+ pip install -r requirements_data_eval.txt
48
+ ```
49
+
50
+ ## Repository Structure
51
+
52
+ This repository contains the following main directories:
53
+
54
+ - **`train/`**: Training scripts and utilities for fine-tuning the Muse model. See [`train/README.md`](train/README.md) for details.
55
+ - **`infer/`**: Inference scripts for generating music with the Muse model. See [`infer/README.md`](infer/README.md) for details.
56
+ - **`eval_pipeline/`**: Evaluation scripts for assessing model performance (Mulan-T, PER, AudioBox, SongEval, etc.).
57
+ - **`data_pipeline/`**: Scripts for building and processing training data, including lyrics generation, metadata processing, and music generation utilities.
58
+
59
+ ## Model Architecture
60
+
61
+ <p align="center">
62
+ <img src="assets/intro.jpg" width="800"/>
63
+ </p>
64
+
65
+ ## Acknowledgments
66
+
67
+ We thank [Qwen3](https://github.com/QwenLM/Qwen3) for providing the base language model, [ms-swift](https://github.com/modelscope/ms-swift) for the training framework, and [MuCodec](https://github.com/tencent-ailab/MuCodec) for discrete audio tokenization.
68
+
69
+ ## Citation
70
+
71
+ If you find our work useful, please cite our paper:
72
+
73
+ ```bibtex
74
+ @article{jiang2026muse,
75
+ title={Muse: Towards Reproducible Long-Form Song Generation with Fine-Grained Style Control},
76
+ author={Jiang, Changhao and Chen, Jiahao and Xiang, Zhenghao and Yang, Zhixiong and Wang, Hanchen and Zhuang, Jiabao and Che, Xinmeng and Sun, Jiajun and Li, Hui and Cao, Yifei and others},
77
+ journal={arXiv preprint arXiv:2601.03973},
78
+ year={2026}
79
+ }
80
+ ```
assets/intro.jpg ADDED

Git LFS Details

  • SHA256: 8dccf43f652372948e48c47afabb0d979823788ae00a4019b8d8d215821284eb
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
baseline_generate/ace_step/convert.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert data to ACE-STEP acceptable format
3
+ """
4
+
5
+ import re
6
+ import json
7
+ import random
8
+ from tqdm import tqdm
9
+
10
+ random.seed(42)
11
+
12
def load_jsonl(path:str) -> list[dict]:
    """Parse every line of *path* as JSON and return the records in order.

    NOTE(review): shadowed by a second ``load_jsonl`` defined later in this
    module — confirm which one is intended to win.
    """
    with open(path, 'r') as file:
        return [json.loads(line) for line in tqdm(file, desc=f"Loading {path}")]
18
+
19
def save_jsonl(data:list, path:str):
    """Serialize each element of *data* as one JSON line in *path*.

    NOTE(review): shadowed by a second ``save_jsonl`` defined later in this
    module — confirm which one is intended to win.
    """
    with open(path, 'w', encoding='utf-8') as file:
        for record in tqdm(data, desc=f"Saving {path}"):
            file.write(json.dumps(record, ensure_ascii=False) + "\n")
24
+
25
+ START_STR = "Please generate a song in the following style:"
26
+ END_STR = "\nNext, I will tell you the requirements and lyrics"
27
+
28
def process_tag(content:str) -> str:
    """Extract the section label from a segment header and normalize it.

    The raw label sits between the leading '[' and the "[desc:" marker.
    Normalization lowercases it, strips digits and parenthesized notes,
    and folds "pre-chorus" into "chorus".
    """
    desc_pos = content.find("[desc:")
    raw = content[1:desc_pos - 1].lower()
    # Drop verse/chorus numbering and any parenthesized qualifier.
    cleaned = re.sub(r'\([^)]*\)', '', re.sub(r'\d+', '', raw)).strip()
    return "[chorus]" if cleaned == "pre-chorus" else f"[{cleaned}]"
40
+
41
def process_lyrics(content:str) -> str:
    """Pull the lyric text out of a segment and normalize its punctuation.

    Returns "" when the segment has no "[lyrics:\n" marker.  Every listed
    punctuation mark (and newline) collapses to a single line break, and a
    single trailing newline is removed.
    """
    marker = "[lyrics:\n"
    begin = content.find(marker)
    if begin == -1:
        return ""
    body = content[begin + len(marker):content.find("][phoneme:")]

    # Map punctuation/newlines to '\n', then squeeze repeated breaks.
    body = re.sub(r'[,。",:;&—‘\'.\]\[()?\n-]', '\n', body)
    while "\n\n" in body:
        body = body.replace("\n\n", "\n")
    return body[:-1] if body.endswith('\n') else body
58
+
59
def has_chinese(text) -> bool:
    """Return True if *text* contains any character from the basic CJK block."""
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)
64
+
65
def process_duration(lyrics:str):
    """Pick a pseudo-random target duration from the lyric length.

    Chinese lyrics are measured in characters, everything else in
    whitespace-split words; the duration is drawn uniformly from
    [0.4 * length, 0.7 * length] seconds.
    """
    if has_chinese(lyrics):
        length = len(lyrics.replace("\n", ""))
    else:
        length = len(lyrics.replace("\n", " ").split())
    return random.randint(int(length * 0.4), int(length * 0.7))
74
+
75
def process_one(messages:list[dict]):
    """Turn one conversation into an ACE-Step generation-request dict.

    The first message carries the overall style description between the
    fixed START_STR/END_STR delimiters; every later non-assistant message
    contributes one "[tag]\\nlyrics" segment.
    """
    # Overall style description.
    style:str = messages[0]['content']
    begin = style.find(START_STR)
    finish = style.find(END_STR)
    descriptions = style[begin + len(START_STR):finish]

    # Accumulate tagged lyric segments; assistant turns are skipped.
    all_lyrics = "[intro]\n\n"
    pure_lyrics = ""
    for message in messages[1:]:
        if message['role'] == "assistant":
            continue
        content = message['content']
        tag = process_tag(content)
        lyric = process_lyrics(content)
        all_lyrics += f"{tag}\n{lyric}\n\n"
        pure_lyrics += lyric
    all_lyrics = all_lyrics[:-2]  # drop the final blank separator

    # Duration derived from the lyric length.
    duration = process_duration(pure_lyrics)

    # Per-song fields plus the fixed ACE-Step sampling parameters.
    return {
        "prompt": descriptions,
        "lyrics": all_lyrics,
        "audio_duration": duration,
        "infer_step": 60,
        "guidance_scale": 15,
        "scheduler_type": "euler",
        "cfg_type": "apg",
        "omega_scale": 10,
        "guidance_interval": 0.5,
        "guidance_interval_decay": 0,
        "min_guidance_scale": 3,
        "use_erg_tag": True,
        "use_erg_lyric": True,
        "use_erg_diffusion": True,
        "oss_steps": [],
        "actual_seeds": [
            3299954530
        ]
    }
122
+
123
def main():
    """Convert the conversation-format dataset into one ACE-Step JSON request file per song."""
    # NOTE(review): placeholder path — replace "xxx" with the real dataset root.
    path = "xxx/ACE-Step/data/inputs/messages.jsonl"
    dataset = load_jsonl(path)

    for id, ele in tqdm(enumerate(dataset), desc="Processing"):
        messages = ele['messages']
        data = process_one(messages)
        # One pretty-printed JSON object per song (despite the .jsonl suffix).
        path = f"./data/inputs/test_{id}.jsonl"
        with open(path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
133
+
134
def load_jsonl(path):
    """Read a JSON-Lines file into a list of objects (no progress bar).

    NOTE(review): this redefinition shadows the annotated ``load_jsonl``
    declared earlier in this module.
    """
    with open(path, 'r') as file:
        return [json.loads(line) for line in file]
140
+
141
def save_jsonl(dataset, path):
    """Write one JSON object per line to *path* (no progress bar).

    NOTE(review): this redefinition shadows the annotated ``save_jsonl``
    declared earlier in this module.
    """
    with open(path, 'w', encoding='utf-8') as file:
        for record in dataset:
            file.write(json.dumps(record, ensure_ascii=False) + "\n")
146
+
147
if __name__ == "__main__":
    # main()
    # Post-processing: restore the original song order of generated outputs.
    dataset = load_jsonl("./data/outputs/lyrics_params.jsonl")
    for ele in dataset:
        path = ele['audio_path']
        # Recover the numeric id embedded in "./data/outputs/test_<id>.wav".
        ele['extra'] = int(path[len("./data/outputs/test_"):-len(".wav")])
    sorted_data = sorted(dataset, key=lambda x: x['extra'])

    # Rewrite the records sorted by id, dropping the temporary sort key.
    save_path = "./data/outputs/lyrics_params_.jsonl"
    with open(save_path, 'w', encoding='utf-8') as file:
        for ele in sorted_data:
            del ele['extra']
            json.dump(ele, file, ensure_ascii=False)
            file.write("\n")
baseline_generate/ace_step/infer.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ import os
3
+ import json
4
+ from acestep.pipeline_ace_step import ACEStepPipeline
5
+ from acestep.data_sampler import DataSampler
6
+
7
+
8
def sample_data(json_data):
    """Flatten an ACE-Step request dict into the positional tuple main() expects.

    Seeds and oss_steps are serialized to comma-separated strings; the two
    optional guidance_scale_* fields default to 0.0 when absent.
    """
    return (
        json_data["audio_duration"],
        json_data["prompt"],
        json_data["lyrics"],
        json_data["infer_step"],
        json_data["guidance_scale"],
        json_data["scheduler_type"],
        json_data["cfg_type"],
        json_data["omega_scale"],
        ", ".join(str(seed) for seed in json_data["actual_seeds"]),
        json_data["guidance_interval"],
        json_data["guidance_interval_decay"],
        json_data["min_guidance_scale"],
        json_data["use_erg_tag"],
        json_data["use_erg_lyric"],
        json_data["use_erg_diffusion"],
        ", ".join(str(step) for step in json_data["oss_steps"]),
        json_data.get("guidance_scale_text", 0.0),
        json_data.get("guidance_scale_lyric", 0.0),
    )
33
+
34
+
35
@click.command()
@click.option(
    "--checkpoint_path", type=str, default="", help="Path to the checkpoint directory"
)
@click.option("--bf16", type=bool, default=True, help="Whether to use bfloat16")
@click.option(
    "--torch_compile", type=bool, default=False, help="Whether to use torch compile"
)
@click.option(
    "--cpu_offload", type=bool, default=False, help="Whether to use CPU offloading (only load current stage's model to GPU)"
)
@click.option(
    "--overlapped_decode", type=bool, default=False, help="Whether to use overlapped decoding (run dcae and vocoder using sliding windows)"
)
@click.option("--device_id", type=int, default=0, help="Device ID to use")
@click.option("--output_path", type=str, default=None, help="Path to save the output")
def main(checkpoint_path, bf16, torch_compile, cpu_offload, overlapped_decode, device_id, output_path):
    """Run ACE-Step generation for every ./data/inputs/test*.jsonl request file.

    NOTE(review): the --output_path option is effectively ignored — the loop
    below rebinds `output_path` per input file.  --device_id is likewise
    unused while the CUDA_VISIBLE_DEVICES line stays commented out.
    """
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    # Build the generation pipeline once and reuse it for every input.
    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode
    )
    print(model_demo)

    # NOTE(review): data_sampler appears unused below — confirm before removing.
    data_sampler = DataSampler()

    inputs_dir = "./data/inputs"
    for id, name in enumerate(os.listdir(inputs_dir)):
        # Only files named test*.jsonl are treated as generation requests.
        if not name.startswith("test"):
            continue
        path = os.path.join(inputs_dir, name)
        with open(path, 'r') as file:
            json_data = json.load(file)
        json_data = sample_data(json_data)

        # Skip inputs whose output already exists (cheap resume support).
        pure_name = os.path.splitext(name)[0]
        output_path = f"./data/outputs/{pure_name}.wav"
        if os.path.exists(output_path):
            continue
        # Unpack the positional tuple produced by sample_data().
        (
            audio_duration,
            prompt,
            lyrics,
            infer_step,
            guidance_scale,
            scheduler_type,
            cfg_type,
            omega_scale,
            manual_seeds,
            guidance_interval,
            guidance_interval_decay,
            min_guidance_scale,
            use_erg_tag,
            use_erg_lyric,
            use_erg_diffusion,
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
        ) = json_data

        # Generate and write the song directly to output_path.
        model_demo(
            audio_duration=audio_duration,
            prompt=prompt,
            lyrics=lyrics,
            infer_step=infer_step,
            guidance_scale=guidance_scale,
            scheduler_type=scheduler_type,
            cfg_type=cfg_type,
            omega_scale=omega_scale,
            manual_seeds=manual_seeds,
            guidance_interval=guidance_interval,
            guidance_interval_decay=guidance_interval_decay,
            min_guidance_scale=min_guidance_scale,
            use_erg_tag=use_erg_tag,
            use_erg_lyric=use_erg_lyric,
            use_erg_diffusion=use_erg_diffusion,
            oss_steps=oss_steps,
            guidance_scale_text=guidance_scale_text,
            guidance_scale_lyric=guidance_scale_lyric,
            save_path=output_path,
        )

if __name__ == "__main__":
    main()
baseline_generate/diffrhythm2/batch_inference.sh ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
set -euo pipefail

# Batch-generate all song_*.jsonl requests with DiffRhythm2.
# Fixes/generalization vs. the original:
# - SONG_DIR / RESULTS_DIR are now overridable via the environment
#   (defaults preserve the original Chinese-test-set behavior);
# - the input path passed to inference.py is the discovered file itself
#   instead of being rebuilt from a second hard-coded directory string.

# Navigate to script directory to ensure relative paths are consistent
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"

SONG_DIR="${SONG_DIR:-$SCRIPT_DIR/example/zh_songs}"
RESULTS_DIR="${RESULTS_DIR:-./results/zh}"

if [ ! -d "$SONG_DIR" ]; then
    echo "Song directory not found: $SONG_DIR"
    exit 1
fi

# Collect all song_*.jsonl files
shopt -s nullglob
SONG_FILES=("$SONG_DIR"/song_*.jsonl)
shopt -u nullglob

if [ ${#SONG_FILES[@]} -eq 0 ]; then
    echo "No song_*.jsonl files in song directory"
    exit 0
fi

export PYTHONPATH="${PYTHONPATH:-}:${SCRIPT_DIR}"

# Fail early if the phonemizer backend is missing.
espeak-ng --version

# Reproducibility settings:
# - Fixed random seed SEED
# - DO_SAMPLE=0 tries to follow deterministic path (including fixed style prompt cropping start)
SEED="${SEED:-42}"
DO_SAMPLE="${DO_SAMPLE:-0}"

# Further reduce cuBLAS non-determinism (enable when needed; comment out if causes errors)
export CUBLAS_WORKSPACE_CONFIG="${CUBLAS_WORKSPACE_CONFIG:-:4096:8}"

for SONG_FILE in "${SONG_FILES[@]}"; do
    SONG_NAME="$(basename "$SONG_FILE")"
    echo "=============================="
    echo "Starting generation: ${SONG_NAME}"
    CMD=(python inference.py
        --repo-id ASLP-lab/DiffRhythm2
        --output-dir "$RESULTS_DIR"
        --input-jsonl "$SONG_FILE"
        --cfg-strength 3.0
        --max-secs 285.0
        --seed "$SEED"
    )
    if [ "$DO_SAMPLE" -eq 1 ]; then
        CMD+=(--do-sample)
    fi
    "${CMD[@]}"
done

echo "All songs generation complete, processed ${#SONG_FILES[@]} songs."
baseline_generate/diffrhythm2/batch_inference_en.sh ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
set -euo pipefail

# Batch-generate all song_*.jsonl requests with DiffRhythm2 (English set).
# Fixes/generalization vs. the original:
# - SONG_DIR / RESULTS_DIR are now overridable via the environment
#   (defaults preserve the original English-test-set behavior);
# - the input path passed to inference.py is the discovered file itself
#   instead of being rebuilt from a second hard-coded directory string.

# Navigate to script directory to ensure relative paths are consistent
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"

SONG_DIR="${SONG_DIR:-$SCRIPT_DIR/example/en_songs}"
RESULTS_DIR="${RESULTS_DIR:-./results/en}"

if [ ! -d "$SONG_DIR" ]; then
    echo "Song directory not found: $SONG_DIR"
    exit 1
fi

# Collect all song_*.jsonl files
shopt -s nullglob
SONG_FILES=("$SONG_DIR"/song_*.jsonl)
shopt -u nullglob

if [ ${#SONG_FILES[@]} -eq 0 ]; then
    echo "No song_*.jsonl files in song directory"
    exit 0
fi

export PYTHONPATH="${PYTHONPATH:-}:${SCRIPT_DIR}"

# Fail early if the phonemizer backend is missing.
espeak-ng --version

# Reproducibility settings:
# - Fixed random seed SEED
# - DO_SAMPLE=0 tries to follow deterministic path (including fixed style prompt cropping start)
SEED="${SEED:-42}"
DO_SAMPLE="${DO_SAMPLE:-0}"

# Further reduce cuBLAS non-determinism (enable when needed; comment out if causes errors)
export CUBLAS_WORKSPACE_CONFIG="${CUBLAS_WORKSPACE_CONFIG:-:4096:8}"

for SONG_FILE in "${SONG_FILES[@]}"; do
    SONG_NAME="$(basename "$SONG_FILE")"
    echo "=============================="
    echo "Starting generation: ${SONG_NAME}"
    CMD=(python inference.py
        --repo-id ASLP-lab/DiffRhythm2
        --output-dir "$RESULTS_DIR"
        --input-jsonl "$SONG_FILE"
        --cfg-strength 3.0
        --max-secs 285.0
        --seed "$SEED"
    )
    if [ "$DO_SAMPLE" -eq 1 ]; then
        CMD+=(--do-sample)
    fi
    "${CMD[@]}"
done

echo "All songs generation complete, processed ${#SONG_FILES[@]} songs."
baseline_generate/diffrhythm2/inference.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torchaudio
17
+ import argparse
18
+ import json
19
+ import os
20
+ from tqdm import tqdm
21
+ import random
22
+ import pedalboard
23
+ import numpy as np
24
+
25
+ from muq import MuQMuLan
26
+ from diffrhythm2.cfm import CFM
27
+ from diffrhythm2.backbones.dit import DiT
28
+ from bigvgan.model import Generator
29
+ from huggingface_hub import hf_hub_download
30
+
31
+
32
+ STRUCT_INFO = {
33
+ "[start]": 500,
34
+ "[end]": 501,
35
+ "[intro]": 502,
36
+ "[verse]": 503,
37
+ "[chorus]": 504,
38
+ "[outro]": 505,
39
+ "[inst]": 506,
40
+ "[solo]": 507,
41
+ "[bridge]": 508,
42
+ "[hook]": 509,
43
+ "[break]": 510,
44
+ "[stop]": 511,
45
+ "[space]": 512
46
+ }
47
+
48
+ lrc_tokenizer = None
49
+
50
+
51
def set_seed(seed: int, deterministic: bool = True):
    """Seed the python, numpy and torch RNGs; optionally request deterministic kernels.

    Determinism is best-effort only: some CUDA kernels remain
    nondeterministic regardless of these flags.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        # warn_only avoids hard failures on ops without deterministic variants
        torch.use_deterministic_algorithms(True, warn_only=True)
    except Exception:
        pass
66
+
67
class CNENTokenizer():
    """Chinese/English grapheme-to-phoneme tokenizer over the bundled g2p module.

    Token ids are shifted by +1 relative to the vocab file, so 0 stays free
    (presumably for padding — confirm against the model's embedding table).
    """

    def __init__(self):
        here = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(here, "g2p/g2p/vocab.json"), 'r') as file:
            self.phone2id:dict = json.load(file)['vocab']
        self.id2phone = {idx: phone for phone, idx in self.phone2id.items()}
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        """Return +1-shifted phoneme ids for *text*."""
        _, token = self.tokenizer(text)
        return [t + 1 for t in token]

    def decode(self, token):
        """Invert encode(): map shifted ids back to a '|'-joined phoneme string."""
        return "|".join(self.id2phone[t - 1] for t in token)
82
+
83
+
84
def prepare_model(repo_id, device):
    """Download and assemble all DiffRhythm2 components on *device*.

    Returns (diffrhythm2 CFM model, MuQ-MuLan style encoder, lyric
    tokenizer, BigVGAN decoder).  NOTE(review): the tokenizer returned here
    is a local variable — the module-level ``lrc_tokenizer`` global stays
    None unless the caller rebinds it (the __main__ block does).
    """
    # Fetch model weights/config from the Hub into ./ckpt (cached).
    diffrhythm2_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir="./ckpt",
        local_files_only=False,
    )
    diffrhythm2_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    with open(diffrhythm2_config_path) as f:
        model_config = json.load(f)

    # Flex attention disabled for this inference path.
    model_config['use_flex_attn'] = False
    diffrhythm2 = CFM(
        transformer=DiT(
            **model_config
        ),
        num_channels=model_config['mel_dim'],
        block_size=model_config['block_size'],
    )

    total_params = sum(p.numel() for p in diffrhythm2.parameters())

    diffrhythm2 = diffrhythm2.to(device)
    # Weights may ship either as safetensors or a plain torch checkpoint.
    if diffrhythm2_ckpt_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        ckpt = load_file(diffrhythm2_ckpt_path)
    else:
        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
    diffrhythm2.load_state_dict(ckpt)
    print(f"Total params: {total_params:,}")

    # load Mulan (style-prompt encoder for both audio and text prompts)
    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device)

    # load frontend
    lrc_tokenizer = CNENTokenizer()

    # load decoder (latent -> waveform vocoder)
    decoder_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.bin",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder = Generator(decoder_config_path, decoder_ckpt_path)
    decoder = decoder.to(device)
    return diffrhythm2, mulan, lrc_tokenizer, decoder
142
+
143
+
144
def parse_lyrics(lyrics: str):
    """Tokenize lyric lines into phoneme-id lists, each terminated by [stop].

    A line that exactly matches a structure tag (e.g. "[chorus]") maps to
    its id from STRUCT_INFO; every other line goes through the module-level
    ``lrc_tokenizer``.
    """
    sequences = []
    for line in lyrics.split("\n"):
        struct_id = STRUCT_INFO.get(line)
        if struct_id is None:
            sequences.append(lrc_tokenizer.encode(line.strip()) + [STRUCT_INFO['[stop]']])
        else:
            sequences.append([struct_id, STRUCT_INFO['[stop]']])
    return sequences
156
+
157
+
158
def make_fake_stereo(audio, sampling_rate):
    """Fabricate a stereo pair from a mono signal of shape (1, n_samples).

    The right channel is the left attenuated to 80% and delayed by 10 ms
    (its leading samples zeroed); the channels are stacked along axis 0,
    yielding shape (2, n_samples).
    """
    delay = int(0.01 * sampling_rate)
    right = np.roll(audio.copy() * 0.8, delay)
    right[:, :delay] = 0
    return np.concatenate([audio, right], axis=0)
168
+
169
+
170
def inference(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    output_dir,
    song_name,
    cfg_strength,
    sample_steps=32,
    process_bar=True,
    fake_stereo=True,
):
    """Generate one song and write it to <output_dir>/<song_name>.mp3.

    text          flat 1-D tensor of phoneme/structure ids (unsqueezed to a batch here)
    style_prompt  MuLan style embedding (unsqueezed to a batch here)
    duration      target length in seconds
    cfg_strength  classifier-free-guidance scale
    """
    with torch.inference_mode():
        latent = model.sample_block_cache(
            text=text.unsqueeze(0),
            # duration * 5: presumably 5 latent frames per second — confirm
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            process_bar=process_bar,
        )
        latent = latent.transpose(1, 2)
        # Chunked vocoding keeps peak memory bounded.
        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20)

    basename = f"{song_name}.mp3"
    output_path = os.path.join(output_dir, basename)

    num_channels = 1
    # (n,) -> (1, n) mono before optional fake-stereo widening.
    audio = audio.float().cpu().numpy().squeeze()[None, :]
    if fake_stereo:
        audio = make_fake_stereo(audio, decoder.h.sampling_rate)
        num_channels = 2

    with pedalboard.io.AudioFile(output_path, "w", decoder.h.sampling_rate, num_channels) as f:
        f.write(audio)
206
+
207
+
208
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument('--repo-id', type=str, default=None)
    parser.add_argument('--output-dir', type=str, default=None)
    parser.add_argument('--input-jsonl', type=str, default=None)
    parser.add_argument('--cfg-strength', type=float, default=2.0)
    parser.add_argument('--max-secs', type=float, default=210.0)
    parser.add_argument('--steps', type=int, default=16)
    # NOTE(review): argparse type=bool treats any non-empty string as True,
    # so --fake-stereo False still enables it — consider action='store_true'.
    parser.add_argument('--fake-stereo', type=bool, default=True)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--do-sample', action='store_true', default=False)

    args = parser.parse_args()

    output_dir = args.output_dir
    input_jsonl = args.input_jsonl
    cfg_strength = args.cfg_strength
    max_secs = args.max_secs
    # NOTE(review): GPU index 7 is hard-coded — fails on machines with fewer
    # GPUs; consider an env var or CLI flag.
    device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): dtype appears unused; the models are .half()-ed below.
    dtype = torch.float16

    # reproducibility
    set_seed(args.seed, deterministic=(not args.do_sample))

    # load diffrhythm2 (also rebinds the module-global lrc_tokenizer used by parse_lyrics)
    diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model(args.repo_id, device)

    # NOTE(review): output_dir was already assigned above — duplicate line.
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    with open(input_jsonl, 'r') as f:
        input_info = [json.loads(i.strip()) for i in f.readlines()]

    for i in tqdm(range(len(input_info))):
        info = input_info[i]
        song_name = info.get('song_name', f"{i:04d}")
        lyrics = info.get('lyrics', None)
        style_prompt = info.get('style_prompt', None)
        if lyrics is None or style_prompt is None:
            print(f"lyrics or style_prompt is None, skip {song_name}")
            continue

        # preprocess lyrics ('lyrics' holds a path to the lyric text file)
        with open(lyrics, 'r') as f:
            lyrics = f.read()
        lyrics_token = parse_lyrics(lyrics)
        lyrics_token = torch.tensor(sum(lyrics_token, []), dtype=torch.long, device=device)

        # preprocess style prompt: a file path means audio prompt, else free text
        if os.path.isfile(style_prompt):
            prompt_wav, sr = torchaudio.load(style_prompt)
            prompt_wav = torchaudio.functional.resample(prompt_wav.to(device), sr, 24000)
            # Crop audio prompts longer than 10 s (random start only when sampling).
            if prompt_wav.shape[1] > 24000 * 10:
                if args.do_sample:
                    start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
                else:
                    start = 0
                prompt_wav = prompt_wav[:, start:start+24000*10]
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
            with torch.no_grad():
                style_prompt_embed = mulan(wavs = prompt_wav)
        else:
            with torch.no_grad():
                style_prompt_embed = mulan(texts = [style_prompt])
        style_prompt_embed = style_prompt_embed.to(device).squeeze(0)

        # fp16 inference on GPU; .half() on already-half models is a no-op.
        if device.type != 'cpu':
            diffrhythm2 = diffrhythm2.half()
            decoder = decoder.half()
            style_prompt_embed = style_prompt_embed.half()

        inference(
            model=diffrhythm2,
            decoder=decoder,
            text=lyrics_token,
            style_prompt=style_prompt_embed,
            duration=max_secs,
            output_dir=output_dir,
            song_name=song_name,
            sample_steps=args.steps,
            cfg_strength=cfg_strength,
            fake_stereo=args.fake_stereo,
        )
baseline_generate/diffrhythm2/inference.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ export PYTHONPATH=$PYTHONPATH:$PWD
3
+ espeak-ng --version
4
+
5
+ python inference.py \
6
+ --repo-id ASLP-lab/DiffRhythm2 \
7
+ --output-dir ./results/test \
8
+ --input-jsonl ./example/song_1.jsonl \
9
+ --cfg-strength 3.0 \
10
+ --max-secs 285.0 \
baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt.cpython-311.pyc ADDED
Binary file (6.37 kB). View file
 
baseline_generate/diffrhythm2/scripts/__pycache__/proce_song_enprompt_justtag.cpython-311.pyc ADDED
Binary file (6.5 kB). View file
 
baseline_generate/diffrhythm2/scripts/proce_song.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Process songs.jsonl to generate corresponding lrc files and jsonl files.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import re
11
+ from pathlib import Path
12
+ from typing import List
13
+
14
# Input conversation file and output directories. The "xxx/" prefixes are
# placeholders — point them at the local checkout before running.
INPUT_JSONL = Path("xxx/diffrhythm2/example/final_zh_test.jsonl")
OUTPUT_SONG_DIR = Path("xxx/diffrhythm2/example/zh_songs")
OUTPUT_LRC_DIR = Path("xxx/diffrhythm2/example/zh_lrc")

# Matches an "[mm:ss]" / "[mm:ss.xxx]" timestamp anywhere in a line.
TIMESTAMP_PATTERN = re.compile(r"\[\d{2}:\d{2}(?:\.\d+)?\]")
# Matches a line consisting of exactly one bracketed tag, e.g. "[Chorus]".
STRUCTURE_PATTERN = re.compile(r"^\[[^\]]+\]$")
20
+
21
+
22
+ def normalize_structure(tag: str) -> str:
23
+ """Convert structure tag to target format."""
24
+ tag_lower = tag.lower()
25
+ if tag_lower.startswith("verse"):
26
+ return "[verse]"
27
+ if "chorus" in tag_lower:
28
+ return "[chorus]"
29
+ if "bridge" in tag_lower:
30
+ return "[bridge]"
31
+ return f"[{tag_lower}]"
32
+
33
+
34
def transform_lyrics(raw_lyrics: str) -> List[str]:
    """Turn raw lyrics text into the LRC line list expected downstream.

    Structure tags (whole-line brackets that are not timestamps) are
    normalized; timestamps are stripped from lyric lines; the result is
    wrapped in [start]/[intro] ... [end] markers.
    """
    out: List[str] = ["[start]", "[intro]"]
    for candidate in raw_lyrics.splitlines():
        stripped = candidate.strip()
        if not stripped:
            continue

        # A whole-line bracket tag that is not a timestamp is a structure marker.
        if STRUCTURE_PATTERN.match(stripped) and not TIMESTAMP_PATTERN.match(stripped):
            out.append(normalize_structure(stripped[1:-1].strip()))
            continue

        # Otherwise drop any embedded timestamps and keep the lyric text.
        cleaned = TIMESTAMP_PATTERN.sub("", stripped).strip()
        if cleaned:
            out.append(cleaned)

    out.append("[end]")
    return out
56
+
57
+
58
def ensure_dirs() -> None:
    """Create the song/lrc output directories if they do not already exist."""
    for directory in (OUTPUT_SONG_DIR, OUTPUT_LRC_DIR):
        directory.mkdir(parents=True, exist_ok=True)
62
+
63
def process_songs() -> None:
    """Read INPUT_JSONL and emit one .lrc file and one per-song .jsonl per entry.

    Each input line is a JSON object with ``description`` and ``lyrics``
    fields; the lyrics are converted to LRC lines, and the description
    becomes the ``style_prompt`` of the per-song jsonl record.
    """
    ensure_dirs()
    with INPUT_JSONL.open("r", encoding="utf-8") as infile:
        for idx, line in enumerate(infile, start=1):
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)
            description = data.get("description", "")
            lyrics_raw = data.get("lyrics", "")

            # Write the converted lyrics as song_<idx>.lrc.
            lrc_lines = transform_lyrics(lyrics_raw)
            lrc_filename = f"song_{idx}.lrc"
            lrc_path = OUTPUT_LRC_DIR / lrc_filename
            lrc_path.write_text("\n".join(lrc_lines), encoding="utf-8")

            # Write the matching single-line jsonl consumed by inference.py.
            song_base = f"song_{idx}"
            song_filename = f"{song_base}.jsonl"
            song_json_path = OUTPUT_SONG_DIR / song_filename
            song_entry = {
                "song_name": song_base,
                "style_prompt": description,
                # NOTE(review): this path is relative to the diffrhythm2
                # working directory — confirm against inference.py's cwd.
                "lyrics": f"example/zh_lrc/{lrc_filename}",
            }
            song_json_path.write_text(json.dumps(song_entry, ensure_ascii=False) + "\n", encoding="utf-8")
            print(f"Processed song {idx}: {song_filename}")


# Script entry point.
if __name__ == "__main__":
    process_songs()
baseline_generate/diffrhythm2/scripts/proce_song_enprompt.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
def extract_user_prompt(messages):
    """Concatenate user-message contents (newline-joined) up to 1600 chars.

    Walks ``messages`` in order and collects the content of every
    ``role == "user"`` entry. Before each append it computes the length the
    joined prompt would reach (counting the joining newline); the first
    segment that would push past 1600 characters stops the scan entirely,
    so no segment is ever truncated mid-paragraph.

    Args:
        messages: list of dicts, each with ``role`` and ``content`` fields.

    Returns:
        The joined prompt string, or "" when no user content was collected.
    """
    collected = []
    total = 0  # length of the prompt built so far

    for message in messages:
        if message.get("role") != "user":
            continue
        segment = message.get("content", "")
        if not segment:
            continue
        # +1 accounts for the "\n" separator once something is collected.
        prospective = total + len(segment) + (1 if collected else 0)
        if prospective > 1600:
            # Keep paragraphs whole: drop this and every later segment.
            break
        collected.append(segment)
        total = prospective

    return "\n".join(collected) if collected else ""
48
+
49
def update_song_file(file_path, new_prompt):
    """Rewrite the first JSON record's ``style_prompt`` in a song jsonl file.

    The file's first non-empty line is parsed as JSON, its ``style_prompt``
    field is replaced with ``new_prompt``, and the file is rewritten.

    Fix: the original rewrote only the first record and replaced any further
    non-empty lines with a single bare newline, silently discarding them.
    They are now preserved verbatim.

    Args:
        file_path: Path to song file.
        new_prompt: New prompt content.
    """
    # Read file content (non-empty lines only).
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    if not lines:
        print(f" Warning: File {file_path} is empty, skipping")
        return

    # Read first JSON data
    try:
        data = json.loads(lines[0])
        # Update style_prompt field
        data['style_prompt'] = new_prompt

        # Write back to file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
            # Keep any additional records instead of discarding them.
            for extra in lines[1:]:
                f.write(extra + '\n')

        print(f" ✓ Updated {file_path}")
    except json.JSONDecodeError as e:
        print(f" Error: JSON parsing failed {file_path}: {e}")
    except Exception as e:
        print(f" Error: Failed to update file {file_path}: {e}")
83
+
84
def main():
    """Copy concatenated user prompts from test_messages.jsonl into song files.

    Entries 1-50 update the Chinese song files, entries 51-100 the English
    ones; each entry i writes into song_<i or i-50>.jsonl under the matching
    directory. The "xxx/" path prefixes are placeholders — adjust before use.
    """
    # File paths
    input_file = "xxx/diffrhythm2/scripts/test_messages.jsonl"
    zh_songs_dir = "xxx/diffrhythm2/example/zh_songs"
    en_songs_dir = "xxx/diffrhythm2/example/en_songs"

    print(f"Reading file: {input_file}")

    # Read all data (non-empty lines only)
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    print(f"Read {len(lines)} entries")

    # Process each entry
    for idx, line in enumerate(lines, 1):
        try:
            data = json.loads(line)
            messages = data.get("messages", [])

            # Extract prompt (newline-joined user contents, capped at 1600 chars)
            prompt = extract_user_prompt(messages)

            if not prompt:
                print(f"Processing entry {idx}: No user content found, skipping")
                continue

            # Determine if Chinese or English
            if idx <= 50:
                # First 50 entries: Chinese songs
                song_num = idx
                target_dir = zh_songs_dir
                lang = "Chinese"
            else:
                # Entries 51-100: English songs
                song_num = idx - 50  # 51->1, 52->2, ..., 100->50
                target_dir = en_songs_dir
                lang = "English"

            # Build file path
            song_file = os.path.join(target_dir, f"song_{song_num}.jsonl")

            print(f"Processing entry {idx} ({lang}, song_{song_num})...")
            print(f" Prompt length: {len(prompt)} characters")

            # Update file
            update_song_file(song_file, prompt)

        except json.JSONDecodeError as e:
            print(f"JSON parsing failed for entry {idx}: {e}")
            continue
        except Exception as e:
            # Best-effort batch: log the traceback and move on to the next entry.
            print(f"Error processing entry {idx}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\nProcessing complete! Processed {len(lines)} entries")

if __name__ == "__main__":
    main()
145
+
baseline_generate/diffrhythm2/scripts/proce_song_enprompt_justtag.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
def extract_user_prompt(messages):
    """Pull the style description out of the first non-empty user message.

    The message is expected to contain
    "Please generate a song in the following style: <style>" followed by a
    newline; only the <style> part (up to the newline, stripped) is kept.

    Args:
        messages: list of dicts, each with ``role`` and ``content`` fields.

    Returns:
        The style string, or "" when the marker (or any user content) is
        absent.
    """
    marker = "Please generate a song in the following style:"

    for message in messages:
        if message.get("role") != "user":
            continue
        content = message.get("content", "")
        if not content:
            continue

        pos = content.find(marker)
        if pos == -1:
            # First non-empty user message lacks the expected marker: give up.
            return ""

        begin = pos + len(marker)
        stop = content.find("\n", begin)
        # Slice up to the newline when present, otherwise to the end.
        raw = content[begin:stop] if stop != -1 else content[begin:]
        return raw.strip()

    return ""
46
+
47
def update_song_file(file_path, new_prompt):
    """Rewrite the first JSON record's ``style_prompt`` in a song jsonl file.

    The file's first non-empty line is parsed as JSON, its ``style_prompt``
    field is replaced with ``new_prompt``, and the file is rewritten.

    Fix: the original rewrote only the first record and replaced any further
    non-empty lines with a single bare newline, silently discarding them.
    They are now preserved verbatim.

    Args:
        file_path: Path to song file.
        new_prompt: New prompt content.
    """
    # Read file content (non-empty lines only).
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    if not lines:
        print(f" Warning: File {file_path} is empty, skipping")
        return

    # Read first JSON data
    try:
        data = json.loads(lines[0])
        # Update style_prompt field
        data['style_prompt'] = new_prompt

        # Write back to file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
            # Keep any additional records instead of discarding them.
            for extra in lines[1:]:
                f.write(extra + '\n')

        print(f" ✓ Updated {file_path}")
    except json.JSONDecodeError as e:
        print(f" Error: JSON parsing failed {file_path}: {e}")
    except Exception as e:
        print(f" Error: Failed to update file {file_path}: {e}")
81
+
82
def main():
    """Copy the extracted style tag from test_messages.jsonl into song files.

    Entries 1-50 update the Chinese song files, entries 51-100 the English
    ones; each entry i writes into song_<i or i-50>.jsonl under the matching
    directory. The "xxx/" path prefixes are placeholders — adjust before use.
    """
    # File paths
    input_file = "xxx/diffrhythm2/scripts/test_messages.jsonl"
    zh_songs_dir = "xxx/diffrhythm2/example/zh_songs"
    en_songs_dir = "xxx/diffrhythm2/example/en_songs"

    print(f"Reading file: {input_file}")

    # Read all data (non-empty lines only)
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    print(f"Read {len(lines)} entries")

    # Process each entry
    for idx, line in enumerate(lines, 1):
        try:
            data = json.loads(line)
            messages = data.get("messages", [])

            # Extract prompt (style tag only, from the first user message)
            prompt = extract_user_prompt(messages)

            if not prompt:
                print(f"Processing entry {idx}: No user content found, skipping")
                continue

            # Determine if Chinese or English
            if idx <= 50:
                # First 50 entries: Chinese songs
                song_num = idx
                target_dir = zh_songs_dir
                lang = "Chinese"
            else:
                # Entries 51-100: English songs
                song_num = idx - 50  # 51->1, 52->2, ..., 100->50
                target_dir = en_songs_dir
                lang = "English"

            # Build file path
            song_file = os.path.join(target_dir, f"song_{song_num}.jsonl")

            print(f"Processing entry {idx} ({lang}, song_{song_num})...")
            print(f" Prompt length: {len(prompt)} characters")

            # Update file
            update_song_file(song_file, prompt)

        except json.JSONDecodeError as e:
            print(f"JSON parsing failed for entry {idx}: {e}")
            continue
        except Exception as e:
            # Best-effort batch: log the traceback and move on to the next entry.
            print(f"Error processing entry {idx}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\nProcessing complete! Processed {len(lines)} entries")

if __name__ == "__main__":
    main()
143
+
baseline_generate/levo/__pycache__/generate.cpython-311.pyc ADDED
Binary file (37.1 kB). View file
 
baseline_generate/levo/convert.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert data to ACE-STEP acceptable format
3
+ """
4
+
5
+ import re
6
+ import json
7
+ import random
8
+ from tqdm import tqdm
9
+
10
+ random.seed(42)
11
+
12
def load_jsonl(path:str) -> list[dict]:
    """Read a .jsonl file and return one parsed dict per line."""
    with open(path, 'r') as file:
        return [json.loads(line) for line in tqdm(file, desc=f"Loading {path}")]
18
+
19
def save_jsonl(data:list, path:str):
    """Write each element of ``data`` as one JSON line to ``path``."""
    with open(path, 'w', encoding='utf-8') as file:
        for record in tqdm(data, desc=f"Saving {path}"):
            file.write(json.dumps(record, ensure_ascii=False))
            file.write("\n")
24
+
25
# Markers delimiting the style description inside the first message's content.
START_STR = "Please generate a song in the following style:"
END_STR = "\nNext, I will tell you the requirements and lyrics"
27
+
28
def process_tag(content:str) -> str:
    """Extract the section label preceding ``[desc:`` and canonicalize it.

    The label is lowercased, digits and parenthesized asides are stripped,
    and "pre-chorus" is folded into "chorus".
    """
    desc_at = content.find("[desc:")
    # Drop the surrounding brackets of the leading "[Label]" segment.
    label = content[1:desc_at - 1].lower()
    label = re.sub(r'\([^)]*\)', '', re.sub(r'\d+', '', label)).strip()
    return "[chorus]" if label == "pre-chorus" else f"[{label}]"
40
+
41
def process_lyrics(content:str) -> str:
    """Extract the lyric text between the ``[lyrics:`` and ``][phoneme:`` markers.

    Punctuation, brackets and newlines are collapsed into single '.'
    separators and a trailing '.' is dropped. Returns "" when the lyrics
    marker is missing.
    """
    marker = "[lyrics:\n"
    begin = content.find(marker)
    if begin == -1:
        return ""
    raw = content[begin + len(marker):content.find("][phoneme:")]

    # Map every punctuation/newline character to '.', then squeeze runs of dots.
    cleaned = re.sub(r'[,。",:;&—‘\'.\]\[()?\n-]', '.', raw)
    while ".." in cleaned:
        cleaned = cleaned.replace("..", ".")
    return cleaned[:-1] if cleaned.endswith('.') else cleaned
58
+
59
def random_size() -> str:
    """Pick a random intro/outro length label."""
    return random.choice(['short', 'medium', 'long'])
63
+
64
def process_one(messages:list[dict]):
    """Convert one conversation into (descriptions, gt_lyric) for ACE-STEP.

    The first message carries the overall style description between
    START_STR and END_STR; every subsequent non-assistant message
    contributes one section (label + lyrics). Random-length intro/outro
    tags wrap the whole lyric string.
    """
    # Overall style description sits between the two fixed marker strings.
    head:str = messages[0]['content']
    descriptions = head[head.find(START_STR) + len(START_STR):head.find(END_STR)]

    # Randomized intro/outro lengths.
    intro_tag = "intro-" + random_size()
    outro_tag = "outro-" + random_size()

    pieces = [f"[{intro_tag}] ;"]
    for message in messages[1:]:
        if message['role'] == "assistant":
            continue
        content = message['content']
        tag = process_tag(content)
        lyric = process_lyrics(content)
        # An empty section or an explicit outro closes the song early.
        if lyric == "" or tag.startswith("[outro"):
            pieces.append(f" [{outro_tag}]")
            break
        pieces.append(f" {tag} {lyric} ;")

    gt_lyric = "".join(pieces)
    # Guarantee the string ends with the outro tag even without an early break.
    if not gt_lyric.endswith(f" [{outro_tag}]"):
        gt_lyric += f" [{outro_tag}]"
    return descriptions, gt_lyric
91
+
92
def main():
    """Convert test_messages.jsonl conversations into ACE-STEP lyric inputs.

    Writes one {"idx", "descriptions", "gt_lyric"} record per conversation.
    The "xxx/" path prefixes are placeholders — adjust before use.
    """
    path = "xxx/SongGeneration/data/inputs/test_messages.jsonl"
    dataset = load_jsonl(path)
    save_path = "xxx/SongGeneration/data/inputs/lyrics.jsonl"

    with open(save_path, 'w', encoding='utf-8') as file:
        for id, ele in tqdm(enumerate(dataset), desc="Processing"):
            messages = ele['messages']
            descriptions, gt_lyric = process_one(messages)
            data = {
                "idx": f"test_{id}",
                "descriptions": descriptions,
                "gt_lyric": gt_lyric
            }
            json.dump(data, file, ensure_ascii=False)
            file.write("\n")

if __name__ == "__main__":
    main()
baseline_generate/levo/generate.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hmac import new
2
+ import sys
3
+ import os
4
+ import argparse
5
+
6
+ import time
7
+ import json
8
+ import torch
9
+ import torchaudio
10
+ import numpy as np
11
+ from omegaconf import OmegaConf
12
+ from codeclm.models import builders
13
+ import gc
14
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
15
+ from codeclm.models import CodecLM
16
+ from third_party.demucs.models.pretrained import get_model_from_yaml
17
+ import re
18
+ import subprocess
19
+
20
# Genre keys accepted in an item's "auto_prompt_audio_type" field; generate()
# uses them to index the prompt-token table loaded from tools/new_prompt.pt.
auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
21
+
22
def get_free_gpu() -> int:
    """Return the index of the GPU with the most free memory (via nvidia-smi)."""
    cmd = "nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits"
    rows = subprocess.check_output(cmd.split()).decode().strip().split("\n")

    # Each row reads "<index>, <free MiB>"; keep (index, free) pairs.
    stats = []
    for row in rows:
        gpu_idx, free_mib = row.split(",")
        stats.append((int(gpu_idx), int(free_mib)))

    # Most free memory wins; ties resolve to the earliest-listed GPU.
    return max(stats, key=lambda pair: pair[1])[0]
35
+
36
class Separator:
    """Demucs-based source separator used to build prompt/vocal/BGM audio.

    Loads an HTDemucs checkpoint onto the GPU with the most free memory and
    exposes ``run()``, which returns 10-second (full, vocal, bgm) excerpts
    at 48 kHz.
    """

    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
        # NOTE(review): the gpu_id argument is ignored — the freest GPU is
        # always selected instead. Confirm this is intended.
        gpu_id = get_free_gpu()
        self.device = f"cuda:{gpu_id}"
        print(f"Using {self.device}")

        # if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
        #     self.device = torch.device(f"cuda:{gpu_id}")
        # else:
        #     self.device = torch.device("cpu")

        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path):
        """Load the Demucs model from YAML config + checkpoint and put it in eval mode."""
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model

    def load_audio(self, f):
        """Load audio, resample to 48 kHz, and keep at most the first 10 seconds."""
        a, fs = torchaudio.load(f)
        if (fs != 48000):
            a = torchaudio.functional.resample(a, fs, 48000)
        if a.shape[-1] >= 48000*10:
            a = a[..., :48000*10]
        # The return slice re-applies the 10 s cap; for (C, T) input the
        # preceding `if` is effectively redundant with it.
        return a[:, 0:48000*10]

    def run(self, audio_path, output_dir='tmp', ext=".flac"):
        """Separate ``audio_path`` and return (full, vocal, bgm) waveforms.

        Reuses a previously separated vocal stem from ``output_dir`` when
        present; otherwise runs Demucs and deletes the non-vocal stems.
        """
        os.makedirs(output_dir, exist_ok=True)
        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
        output_paths = []

        for stem in self.demucs_model.sources:
            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
            if os.path.exists(output_path):
                output_paths.append(output_path)
        # Only the vocal stem survives a prior run (the other three stems are
        # removed below), hence the == 1 check rather than == 4.
        if len(output_paths) == 1:  # 4
            vocal_path = output_paths[0]
        else:
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
            for path in [drums_path, bass_path, other_path]:
                os.remove(path)
        full_audio = self.load_audio(audio_path)
        vocal_audio = self.load_audio(vocal_path)
        # BGM is whatever remains after subtracting the vocal stem.
        bgm_audio = full_audio - vocal_audio
        return full_audio, vocal_audio, bgm_audio
82
+
83
+
84
def parse_args():
    """Assemble the CLI parser for the song-generation script and parse argv."""
    cli = argparse.ArgumentParser(description='Song Generation Script')

    # Required parameters
    cli.add_argument('--ckpt_path', type=str, required=True,
                     help='Path to the checkpoint directory containing config.yaml and model.pt')
    cli.add_argument('--input_jsonl', type=str, required=True,
                     help='Path to input JSONL file containing generation tasks')
    cli.add_argument('--save_dir', type=str, required=True,
                     help='Directory to save generated audio files and results')
    # Optional parameters
    cli.add_argument('--generate_type', type=str, default='mixed',
                     help='Type of generation: "vocal" or "bgm" or "separate" or "mixed" (default: "mixed")')
    cli.add_argument('--use_flash_attn', action='store_true',
                     help='Whether to use flash attention (default: False)')
    cli.add_argument('--low_mem', action='store_true',
                     help='Whether to use low memory mode (default: False)')
    return cli.parse_args()
102
+
103
def generate(args):
    """Full-memory song-generation pipeline.

    Phases: (1) load config and helpers; (2) per input item, build the
    prompt tokens — from a reference audio file, from a preset genre entry
    in tools/new_prompt.pt, or none; (3) load the separate tokenizer and
    the language model; (4) generate tokens and decode audio per item,
    saving .flac files and a result jsonl under ``args.save_dir``.
    Requires a CUDA device (models are moved with ``.cuda()``).
    """
    torch.set_num_threads(1)
    ckpt_path = args.ckpt_path
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    cfg_path = os.path.join(ckpt_path, 'config.yaml')
    ckpt_path = os.path.join(ckpt_path, 'model.pt')
    cfg = OmegaConf.load(cfg_path)
    cfg.lm.use_flash_attn_2 = args.use_flash_attn
    print(f"use_flash_attn: {args.use_flash_attn}")
    cfg.mode = 'inference'
    max_duration = cfg.max_dur
    gen_type = args.generate_type


    separator = Separator()
    # Preset prompt tokens keyed by genre (see auto_prompt_type).
    auto_prompt = torch.load('tools/new_prompt.pt')
    audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
    audio_tokenizer = audio_tokenizer.eval().cuda()
    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()


    # Phase 2: resolve each item's prompt (audio file > genre preset > none).
    new_items = []
    for line in lines:
        item = json.loads(line)
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        # get prompt audio
        if "prompt_audio_path" in item:
            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
            with torch.no_grad():
                pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
            # Keep the raw waveforms for the decode stage.
            item['raw_pmt_wav'] = pmt_wav
            item['raw_vocal_wav'] = vocal_wav
            item['raw_bgm_wav'] = bgm_wav
            # Normalize each waveform to [B, C, T] before tokenization.
            if pmt_wav.dim() == 2:
                pmt_wav = pmt_wav[None]
            if pmt_wav.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            pmt_wav = list(pmt_wav)
            if vocal_wav.dim() == 2:
                vocal_wav = vocal_wav[None]
            if vocal_wav.dim() != 3:
                raise ValueError("Vocal wavs should have a shape [B, C, T].")
            vocal_wav = list(vocal_wav)
            if bgm_wav.dim() == 2:
                bgm_wav = bgm_wav[None]
            if bgm_wav.dim() != 3:
                raise ValueError("BGM wavs should have a shape [B, C, T].")
            bgm_wav = list(bgm_wav)
            if type(pmt_wav) == list:
                pmt_wav = torch.stack(pmt_wav, dim=0)
            if type(vocal_wav) == list:
                vocal_wav = torch.stack(vocal_wav, dim=0)
            if type(bgm_wav) == list:
                bgm_wav = torch.stack(bgm_wav, dim=0)
            pmt_wav = pmt_wav
            vocal_wav = vocal_wav
            bgm_wav = bgm_wav
            with torch.no_grad():
                pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
            melody_is_wav = False
        elif "auto_prompt_audio_type" in item:
            # Pick a random preset prompt token for the requested genre.
            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
            pmt_wav = prompt_token[:,[0],:]
            vocal_wav = prompt_token[:,[1],:]
            bgm_wav = prompt_token[:,[2],:]
            melody_is_wav = False
        else:
            # No prompt at all: generation is conditioned on lyrics/description only.
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True
        item['pmt_wav'] = pmt_wav
        item['vocal_wav'] = vocal_wav
        item['bgm_wav'] = bgm_wav
        item['melody_is_wav'] = melody_is_wav
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name
        new_items.append(item)

    # Free the prompt-stage models before loading the LM.
    del audio_tokenizer
    del separator

    torch.cuda.empty_cache()

    if "audio_tokenizer_checkpoint_sep" in cfg.keys():
        seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
    else:
        seperate_tokenizer = None

    if seperate_tokenizer is not None:
        seperate_tokenizer = seperate_tokenizer.eval().cuda()

    # Tokenize the separated vocal/BGM prompts with the separate tokenizer.
    for item in new_items:
        if "prompt_audio_path" in item:
            with torch.no_grad():
                vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
            item['vocal_wav'] = vocal_wav
            item['bgm_wav'] = bgm_wav

    # Phase 3: load the language model (fp16 on GPU).
    torch.cuda.empty_cache()
    audiolm = builders.get_lm_model(cfg)
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
    audiolm.load_state_dict(audiolm_state_dict, strict=False)
    audiolm = audiolm.eval()
    audiolm = audiolm.cuda().to(torch.float16)

    model = CodecLM(name = "tmp",
                    lm = audiolm,
                    audiotokenizer = None,
                    max_duration = max_duration,
                    seperate_tokenizer = seperate_tokenizer,
                    )

    # Sampling hyper-parameters.
    cfg_coef = 1.5 #25
    temp = 0.9
    top_k = 50
    top_p = 0.0
    record_tokens = True
    record_window = 50

    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)

    # Phase 4: generate tokens, decode audio, and save per item.
    for item in new_items:
        lyric = item["gt_lyric"]
        descriptions = item["descriptions"] if "descriptions" in item else None
        pmt_wav = item['pmt_wav']
        vocal_wav = item['vocal_wav']
        bgm_wav = item['bgm_wav']
        melody_is_wav = item['melody_is_wav']
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"


        generate_inp = {
            # NOTE(review): this replace looks like it should collapse double
            # spaces ("  " -> " "); as written it is a no-op — confirm against
            # the upstream SongGeneration source.
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [descriptions],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }
        start_time = time.time()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            with torch.no_grad():
                tokens = model.generate(**generate_inp, return_tokens=True)
        mid_time = time.time()

        with torch.no_grad():
            if 'raw_pmt_wav' in item:
                # Decode conditioned on the raw prompt waveforms.
                if gen_type == 'separate':
                    wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='mixed')
                    wav_vocal = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='bgm')
                elif gen_type == 'mixed':
                    wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
                else:
                    wav_seperate = model.generate_audio(tokens,chunked=True, gen_type=gen_type)
                del item['raw_pmt_wav']
                del item['raw_vocal_wav']
                del item['raw_bgm_wav']
            else:
                if gen_type == 'separate':
                    wav_vocal = model.generate_audio(tokens, chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(tokens, chunked=True, gen_type='bgm')
                    wav_seperate = model.generate_audio(tokens, chunked=True, gen_type='mixed')
                else:
                    wav_seperate = model.generate_audio(tokens, chunked=True, gen_type=gen_type)
        # Drop tensors from the item so it stays JSON-serializable below.
        del item['pmt_wav']
        del item['vocal_wav']
        del item['bgm_wav']
        del item['melody_is_wav']
        end_time = time.time()
        if gen_type == 'separate':
            torchaudio.save(target_wav_name.replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(target_wav_name.replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
        else:
            torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)

        print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name

    # Persist the processed items (now including wav_path) next to the audio.
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
299
def generate_lowmem(args):
    """Low-memory song generation pipeline.

    Runs the same jsonl -> audio flow as `generate`, but loads/uses/frees the
    heavy models one stage at a time (prompt tokenizers -> LM -> diffusion
    decoder), optionally CPU-offloading the LM via OffloadProfiler, so it fits
    on GPUs with less free memory.

    Args:
        args: parsed CLI namespace; reads ckpt_path, input_jsonl, save_dir,
            use_flash_attn, generate_type.

    Side effects: writes audio to {save_dir}/audios and a result jsonl to
    {save_dir}/jsonl; mutates CUDA memory state heavily.
    """
    torch.set_num_threads(1)
    ckpt_path = args.ckpt_path
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    # ckpt_path is reused: first the model dir, then the weight file inside it.
    cfg_path = os.path.join(ckpt_path, 'config.yaml')
    ckpt_path = os.path.join(ckpt_path, 'model.pt')
    cfg = OmegaConf.load(cfg_path)
    cfg.lm.use_flash_attn_2 = args.use_flash_attn
    print(f"use_flash_attn: {args.use_flash_attn}")
    cfg.mode = 'inference'
    max_duration = cfg.max_dur
    gen_type = args.generate_type
    chunk_size = 128  # NOTE(review): unused in this function — candidate for removal.
    use_audio_tokenizer = False
    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()
    # Only instantiate the separator + prompt tokenizer if any item actually
    # supplies a reference audio prompt.
    for line in lines:
        item = json.loads(line)
        if "prompt_audio_path" in item:
            use_audio_tokenizer = True
            break
    if use_audio_tokenizer:
        separator = Separator()
        audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
        audio_tokenizer = audio_tokenizer.eval().cuda()
    # Pre-baked prompt tokens keyed by audio type (used when no prompt audio given).
    auto_prompt = torch.load('tools/new_prompt.pt')
    new_items = []
    # Stage 1: build per-item prompt tensors (wav prompts are separated into
    # mixed/vocal/bgm stems and the mixed stem is tokenized here).
    for line in lines:
        item = json.loads(line)
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        # get prompt audio
        if "prompt_audio_path" in item:
            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
            with torch.no_grad():
                pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
            # Raw stems are kept for the decoding stage later on.
            item['raw_pmt_wav'] = pmt_wav
            item['raw_vocal_wav'] = vocal_wav
            item['raw_bgm_wav'] = bgm_wav
            # Normalize each stem to [B, C, T].
            if pmt_wav.dim() == 2:
                pmt_wav = pmt_wav[None]
            if pmt_wav.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            pmt_wav = list(pmt_wav)
            if vocal_wav.dim() == 2:
                vocal_wav = vocal_wav[None]
            if vocal_wav.dim() != 3:
                raise ValueError("Vocal wavs should have a shape [B, C, T].")
            vocal_wav = list(vocal_wav)
            if bgm_wav.dim() == 2:
                bgm_wav = bgm_wav[None]
            if bgm_wav.dim() != 3:
                raise ValueError("BGM wavs should have a shape [B, C, T].")
            bgm_wav = list(bgm_wav)
            # list() above then re-stacked — net effect is a batch dim rebuild.
            if type(pmt_wav) == list:
                pmt_wav = torch.stack(pmt_wav, dim=0)
            if type(vocal_wav) == list:
                vocal_wav = torch.stack(vocal_wav, dim=0)
            if type(bgm_wav) == list:
                bgm_wav = torch.stack(bgm_wav, dim=0)
            with torch.no_grad():
                pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
            melody_is_wav = False
        elif "auto_prompt_audio_type" in item:
            # Pick a random pre-tokenized prompt of the requested genre/type.
            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
            # Channel 0/1/2 of the stored token tensor = mixed/vocal/bgm.
            pmt_wav = prompt_token[:,[0],:]
            vocal_wav = prompt_token[:,[1],:]
            bgm_wav = prompt_token[:,[2],:]
            melody_is_wav = False
        else:
            # Unconditional (no melody prompt) generation.
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True
        item['pmt_wav'] = pmt_wav
        item['vocal_wav'] = vocal_wav
        item['bgm_wav'] = bgm_wav
        item['melody_is_wav'] = melody_is_wav
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name
        new_items.append(item)

    # Free stage-1 models before loading the next stage.
    if use_audio_tokenizer:
        del audio_tokenizer
        del separator

    torch.cuda.empty_cache()

    if "audio_tokenizer_checkpoint_sep" in cfg.keys() and use_audio_tokenizer:
        seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
    else:
        seperate_tokenizer = None

    if seperate_tokenizer is not None:
        seperate_tokenizer = seperate_tokenizer.eval().cuda()

    # Stage 2: tokenize vocal/bgm stems for items with prompt audio.
    for item in new_items:
        if "prompt_audio_path" in item:
            with torch.no_grad():
                vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
            item['vocal_wav'] = vocal_wav
            item['bgm_wav'] = bgm_wav

    if use_audio_tokenizer:
        del seperate_tokenizer

    torch.cuda.empty_cache()

    # Define model or load pretrained model
    audiolm = builders.get_lm_model(cfg)
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # Strip the 'audiolm.' prefix the training wrapper added to parameter names.
    audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
    audiolm.load_state_dict(audiolm_state_dict, strict=False)
    audiolm = audiolm.eval()

    # Optionally CPU-offload LM layers instead of keeping them resident on GPU.
    offload_audiolm = True if 'offload' in cfg.keys() and 'audiolm' in cfg.offload else False
    if offload_audiolm:
        audiolm_offload_param = OffloadParamParse.parse_config(audiolm, cfg.offload.audiolm)
        audiolm_offload_param.show()
        offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
        offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
        offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
    else:
        audiolm = audiolm.cuda().to(torch.float16)

    # LM-only wrapper: tokenizers are attached later for the decoding stage.
    model = CodecLM(name = "tmp",
        lm = audiolm,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = None,
    )

    # Sampling hyper-parameters for token generation.
    cfg_coef = 1.5 #25
    temp = 0.9
    top_k = 50
    top_p = 0.0
    record_tokens = True
    record_window = 50


    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)


    # Stage 3: LM token generation for every item.
    for item in new_items:
        lyric = item["gt_lyric"]
        descriptions = item["descriptions"] if "descriptions" in item else None
        pmt_wav = item['pmt_wav']
        vocal_wav = item['vocal_wav']
        bgm_wav = item['bgm_wav']
        melody_is_wav = item['melody_is_wav']

        generate_inp = {
            # NOTE(review): this replace() appears to substitute a space-like
            # character with a plain space — confirm the source character
            # survived the diff rendering.
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [descriptions],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            with torch.no_grad():
                tokens = model.generate(**generate_inp, return_tokens=True)
        if offload_audiolm:
            offload_profiler.reset_empty_cache_mem_line()
        item['tokens'] = tokens
    # Tear down the LM completely before loading the diffusion decoder.
    if offload_audiolm:
        offload_profiler.stop()
        del offload_profiler
        del audiolm_offload_param
    del model
    audiolm = audiolm.cpu()
    del audiolm
    del checkpoint
    gc.collect()
    torch.cuda.empty_cache()

    # Stage 4: load the separate tokenizer on CPU, then move pieces to GPU.
    seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint_sep, cfg)
    device = "cuda:0"
    seperate_tokenizer.model.device = device
    seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
    seperate_tokenizer.model.model.device = torch.device(device)
    seperate_tokenizer = seperate_tokenizer.eval()

    # offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
    # Diffusion-stage offload is intentionally disabled (see commented line above).
    offload_wav_tokenizer_diffusion = False
    if offload_wav_tokenizer_diffusion:
        sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
        sep_offload_param.show()
        sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
        sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
        sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
    else:
        seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)

    # Decoder-only wrapper (no LM) for token -> waveform synthesis.
    model = CodecLM(name = "tmp",
        lm = None,
        audiotokenizer = None,
        max_duration = max_duration,
        seperate_tokenizer = seperate_tokenizer,
    )

    # Stage 5: decode tokens to audio and save.
    for item in new_items:
        with torch.no_grad():
            if 'raw_pmt_wav' in item:
                if gen_type == 'separate':
                    wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type='mixed')
                    wav_vocal = model.generate_audio(item['tokens'],chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
                elif gen_type == 'mixed':
                    wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
                else:
                    wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
                # Raw stems are no longer needed (and are not JSON-serializable).
                del item['raw_pmt_wav']
                del item['raw_vocal_wav']
                del item['raw_bgm_wav']
            else:
                if gen_type == 'separate':
                    wav_vocal = model.generate_audio(item['tokens'], chunked=True, gen_type='vocal')
                    wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
                    wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type='mixed')
                else:
                    wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
        if gen_type == 'separate':
            torchaudio.save(item['wav_path'].replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(item['wav_path'].replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
            torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
        else:
            torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
        # Strip tensors so the item can be serialized to jsonl below.
        del item['tokens']
        del item['pmt_wav']
        del item['vocal_wav']
        del item['bgm_wav']
        del item['melody_is_wav']
        if offload_wav_tokenizer_diffusion:
            sep_offload_profiler.reset_empty_cache_mem_line()

    if offload_wav_tokenizer_diffusion:
        sep_offload_profiler.stop()
    torch.cuda.empty_cache()
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
548
+
549
+
550
if __name__ == "__main__":
    torch.backends.cudnn.enabled = False
    # SECURITY NOTE(review): the 'eval' resolver executes arbitrary Python from
    # the YAML config — acceptable only for trusted, local config files.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
    np.random.seed(int(time.time()))
    # Parse command line arguments
    args = parse_args()
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        reserved = torch.cuda.memory_reserved(device)
        total = torch.cuda.get_device_properties(device).total_memory
        # NOTE(review): res_mem is total minus reserved, i.e. the memory still
        # AVAILABLE — the "reserved memory" wording in the print is misleading.
        res_mem = (total - reserved) / 1024 / 1024 / 1024
        print(f"reserved memory: {res_mem}GB")

        # Model variant is inferred from the checkpoint folder name.
        model_name = args.ckpt_path.split("/")[-1].lower().replace('-', '_')
        assert model_name in ['songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full', 'songgeneration_large'], f'{model_name} is not supported, currently only songgeneration_base, songgeneration_base_new, songgeneration_base_full, songgeneration_large are supported. Please download correct files and rename the folder to the corresponding version name.'
        # Pick the full-GPU path when enough free memory is available and the
        # user did not force low-memory mode; otherwise use the offload path.
        if model_name == 'songgeneration_base' or model_name == 'songgeneration_base_new' or model_name == 'songgeneration_base_full':
            if res_mem > 24 and not args.low_mem:
                print("use generate")
                generate(args)
            else:
                # Offload helpers are imported lazily: only needed for low-mem.
                from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
                print("use generate_lowmem")
                generate_lowmem(args)
        elif model_name == 'songgeneration_large':
            if res_mem > 36 and not args.low_mem:
                print("use generate")
                generate(args)
            else:
                print("use generate_lowmem")
                from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
                generate_lowmem(args)


        # elif model_name == 'songgeneration_base_full':

    else:
        print("CUDA is not available")
        exit()
591
+
baseline_generate/mureka_o2/__pycache__/generate.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
baseline_generate/mureka_o2/generate.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script for batch generating songs using Mureka API
4
+ Processes the first 100 songs in cleaned_data_no_desc.json file
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ import requests
11
+ from typing import Dict, List, Optional
12
+ from pathlib import Path
13
+
14
# API Configuration
API_URL = "https://api.mureka.cn/v1/song/generate"  # task-creation endpoint (POST)
QUERY_API_URL = "https://api.mureka.cn/v1/song/query"  # status endpoint (GET {url}/{task_id})
API_KEY_ENV = "MUREKA_API_KEY"  # environment variable holding the bearer token
MODEL = "mureka-o2"

# Configuration Parameters
MAX_SONGS = 100  # only the first MAX_SONGS entries of the input file are processed
RETRY_TIMES = 3
RETRY_DELAY = 2  # seconds
REQUEST_DELAY = 60  # Delay between requests (seconds) - set to 60 seconds (1 minute) to avoid rate limiting
RATE_LIMIT_DELAY = 60  # Wait time when encountering 429 error (seconds)
QUERY_INTERVAL = 10  # Interval for querying task status (seconds)
MAX_QUERY_TIME = 3600  # Maximum query time (seconds), 1 hour
28
+
29
def load_songs(json_file: str, max_count: int = MAX_SONGS) -> List[Dict]:
    """Read a JSON array of song entries and return the first *max_count*."""
    with open(json_file, 'r', encoding='utf-8') as fh:
        entries = json.load(fh)

    # Truncate to the requested number of songs.
    return entries[:max_count]
36
+
37
def is_song_processed(output_file: Path) -> bool:
    """
    Report whether this song already has a recorded API response on disk.

    Args:
        output_file: Output file path

    Returns:
        True if file exists and contains valid API response
    """
    if not output_file.exists():
        return False

    try:
        with open(output_file, 'r', encoding='utf-8') as fh:
            saved = json.load(fh)
        if 'api_response' in saved and saved['api_response']:
            state = saved['api_response'].get('status', '')
            # Terminal states: nothing more to do for this song.
            # In-flight states: a task was already created, so it still
            # counts as processed and must not be re-submitted.
            if state in ('succeeded', 'failed', 'timeouted', 'cancelled') or \
                    state in ('preparing', 'queued', 'running', 'streaming', 'reviewing'):
                return True
    except (json.JSONDecodeError, KeyError, IOError):
        # Corrupted or unreadable file -> treat as unprocessed.
        return False

    return False
67
+
68
def load_processed_song(output_file: Path) -> Optional[Dict]:
    """
    Read back the API response previously saved for a song.

    Args:
        output_file: Output file path

    Returns:
        The stored ``api_response`` payload, or None when the file is
        missing, unreadable, or not valid JSON.
    """
    try:
        with open(output_file, 'r', encoding='utf-8') as fh:
            saved = json.load(fh)
            return saved.get('api_response')
    except (json.JSONDecodeError, KeyError, IOError):
        return None
84
+
85
def query_task_status(task_id: str, api_key: str) -> Optional[Dict]:
    """
    Fetch the current state of a generation task.

    Args:
        task_id: Task ID
        api_key: API key

    Returns:
        Parsed JSON status payload, or None when the request fails.
    """
    auth_header = {
        "Authorization": f"Bearer {api_key}"
    }
    endpoint = f"{QUERY_API_URL}/{task_id}"

    try:
        reply = requests.get(endpoint, headers=auth_header, timeout=30)
        reply.raise_for_status()
        return reply.json()
    except requests.exceptions.RequestException as exc:
        # Polling failures are non-fatal; the caller just retries later.
        print(f" Failed to query task status: {str(exc)}")
        return None
109
+
110
def wait_for_task_completion(task_id: str, api_key: str) -> Optional[Dict]:
    """
    Poll a task until it reaches a terminal state or the time budget runs out.

    Args:
        task_id: Task ID
        api_key: API key

    Returns:
        The final task payload, or None when polling exceeded MAX_QUERY_TIME.
    """
    started_at = time.time()
    previous_status = None

    print(f" Waiting for task completion (Task ID: {task_id})...")

    while time.time() - started_at < MAX_QUERY_TIME:
        snapshot = query_task_status(task_id, api_key)

        if not snapshot:
            # Transient query failure — just wait and poll again.
            time.sleep(QUERY_INTERVAL)
            continue

        status = snapshot.get('status', '')

        # Announce transitions only, not every poll.
        if status != previous_status:
            print(f" Status: {status}")
            previous_status = status

        if status not in ['succeeded', 'failed', 'timeouted', 'cancelled']:
            # Still in flight; poll again after the configured interval.
            time.sleep(QUERY_INTERVAL)
            continue

        # Terminal state reached (success or failure).
        if status == 'succeeded':
            print(f" ✓ Task completed!")
            if 'choices' in snapshot and snapshot['choices']:
                print(f" Found {len(snapshot['choices'])} generated songs")
        else:
            print(f" ✗ Task failed: {status}")
            if 'failed_reason' in snapshot:
                print(f" Failure reason: {snapshot['failed_reason']}")
        return snapshot

    print(f" ⚠ Query timeout (exceeded {MAX_QUERY_TIME} seconds)")
    return None
157
+
158
def generate_song(lyrics: str, prompt: str, api_key: str) -> Optional[Dict]:
    """
    Submit one song-generation request to the Mureka API (strictly serial).

    Retries up to RETRY_TIMES, honouring the Retry-After header on HTTP 429.

    Args:
        lyrics: Lyrics content
        prompt: Prompt (corresponds to description)
        api_key: API key

    Returns:
        Parsed JSON of the task-creation response, or None after all
        attempts failed.
    """
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    request_body = {
        "lyrics": lyrics,
        "model": MODEL,
        "prompt": prompt
    }

    for attempt in range(RETRY_TIMES):
        try:
            reply = requests.post(API_URL, headers=request_headers, json=request_body, timeout=300)

            # Rate limited: back off before retrying.
            if reply.status_code == 429:
                # Prefer the server-suggested wait from the Retry-After header.
                retry_after = reply.headers.get('Retry-After')
                wait_time = int(retry_after) if retry_after else RATE_LIMIT_DELAY

                print(f" Attempt {attempt + 1}/{RETRY_TIMES} failed: 429 Too Many Requests")
                print(f" Waiting {wait_time} seconds before retry...")
                if attempt >= RETRY_TIMES - 1:
                    print(f" All retries failed")
                    return None
                time.sleep(wait_time)
                continue

            reply.raise_for_status()
            return reply.json()

        except requests.exceptions.HTTPError as exc:
            if exc.response and exc.response.status_code == 429:
                # 429 is already handled above; move on to the next attempt.
                continue
            print(f" Attempt {attempt + 1}/{RETRY_TIMES} failed: {str(exc)}")
            if attempt < RETRY_TIMES - 1:
                time.sleep(RETRY_DELAY)
            else:
                print(f" All retries failed")
                return None
        except requests.exceptions.RequestException as exc:
            print(f" Attempt {attempt + 1}/{RETRY_TIMES} failed: {str(exc)}")
            if attempt < RETRY_TIMES - 1:
                time.sleep(RETRY_DELAY)
            else:
                print(f" All retries failed")
                return None

    return None
225
+
226
def main():
    """Main function: serially generate songs for the first MAX_SONGS entries
    of cleaned_data_no_desc.json, persisting per-song results and a summary
    under ./generated_songs. Safe to rerun: already-processed songs are skipped.
    """
    # Check API key
    api_key = os.getenv(API_KEY_ENV)
    if not api_key:
        print(f"Error: Please set environment variable {API_KEY_ENV}")
        print(f"Example: export {API_KEY_ENV}=your_api_key")
        return

    # Load song data
    json_file = "cleaned_data_no_desc.json"
    if not os.path.exists(json_file):
        print(f"Error: File not found {json_file}")
        return

    print(f"Loading song data from {json_file}...")
    songs = load_songs(json_file, MAX_SONGS)
    print(f"Loaded {len(songs)} songs")

    # Create output directory
    output_dir = Path("generated_songs")
    output_dir.mkdir(exist_ok=True)

    # List to save results
    results = []

    # Process each song (serial processing, ensuring concurrency = 1)
    for idx, song in enumerate(songs, 1):
        print(f"\n[{idx}/{len(songs)}] Processing song...")
        print(f" Description: {song.get('description', 'N/A')[:50]}...")

        # Check output file path
        output_file = output_dir / f"song_{idx:03d}.json"

        # Check if already processed
        if is_song_processed(output_file):
            print(f" ⊙ Already processed, skipping")
            # NOTE(review): existing_result is loaded but never used —
            # candidate for removal or for embedding in the result entry.
            existing_result = load_processed_song(output_file)
            results.append({
                "index": idx,
                "status": "already_processed",
                "output_file": str(output_file)
            })
            # Already processed songs don't need delay
            continue

        lyrics = song.get('lyrics', '')
        description = song.get('description', '')

        if not lyrics or not description:
            print(f" Skipping: missing lyrics or description")
            results.append({
                "index": idx,
                "status": "skipped",
                "reason": "missing data"
            })
            continue

        # Call API (serial execution, ensuring concurrency = 1)
        result = generate_song(lyrics, description, api_key)

        if result:
            task_id = result.get('id')
            initial_status = result.get('status', '')
            print(f" ✓ Task created (ID: {task_id}, Status: {initial_status})")

            # If task status is not final, wait for task completion
            if initial_status not in ['succeeded', 'failed', 'timeouted', 'cancelled']:
                final_result = wait_for_task_completion(task_id, api_key)
                if final_result:
                    result = final_result
                else:
                    # Query timeout, use initial result
                    print(f" ⚠ Using initial result (query timeout)")

            # Save single result (including final status and choices)
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump({
                    "index": idx,
                    "original_data": song,
                    "api_response": result,
                    "task_id": task_id
                }, f, ensure_ascii=False, indent=2)

            # Check if successfully completed
            final_status = result.get('status', '')
            if final_status == 'succeeded':
                results.append({
                    "index": idx,
                    "status": "success",
                    "output_file": str(output_file),
                    "task_id": task_id,
                    "has_audio": 'choices' in result and len(result.get('choices', [])) > 0
                })
            elif final_status in ['failed', 'timeouted', 'cancelled']:
                results.append({
                    "index": idx,
                    "status": "failed",
                    "output_file": str(output_file),
                    "task_id": task_id,
                    "failed_reason": result.get('failed_reason', final_status)
                })
            else:
                # Task still processing
                results.append({
                    "index": idx,
                    "status": "processing",
                    "output_file": str(output_file),
                    "task_id": task_id,
                    "current_status": final_status
                })
        else:
            print(f" ✗ Generation failed")
            # Save failure information, including original data
            error_file = output_dir / f"song_{idx:03d}_error.json"
            with open(error_file, 'w', encoding='utf-8') as f:
                json.dump({
                    "index": idx,
                    "original_data": song,
                    "error": "API call failed, generate_song returned None",
                    "timestamp": time.time()
                }, f, ensure_ascii=False, indent=2)

            results.append({
                "index": idx,
                "status": "failed",
                "error_file": str(error_file),
                "reason": "API call failed"
            })

        # Delay between requests to avoid rate limiting (ensuring concurrency = 1)
        if idx < len(songs):
            print(f" Waiting {REQUEST_DELAY} seconds before processing next song...")
            time.sleep(REQUEST_DELAY)

    # Save summary results
    summary_file = output_dir / "summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump({
            "total": len(songs),
            "success": sum(1 for r in results if r.get("status") == "success"),
            "processing": sum(1 for r in results if r.get("status") == "processing"),
            "already_processed": sum(1 for r in results if r.get("status") == "already_processed"),
            "failed": sum(1 for r in results if r.get("status") == "failed"),
            "skipped": sum(1 for r in results if r.get("status") == "skipped"),
            "results": results
        }, f, ensure_ascii=False, indent=2)

    # Print statistics
    print(f"\n{'='*50}")
    print(f"Processing complete!")
    print(f"Total: {len(songs)} songs")
    print(f"Successfully completed: {sum(1 for r in results if r.get('status') == 'success')} songs")
    print(f"Processing: {sum(1 for r in results if r.get('status') == 'processing')} songs")
    print(f"Already processed: {sum(1 for r in results if r.get('status') == 'already_processed')} songs")
    print(f"Failed: {sum(1 for r in results if r.get('status') == 'failed')} songs")
    print(f"Skipped: {sum(1 for r in results if r.get('status') == 'skipped')} songs")
    print(f"Results saved in: {output_dir}/")
    print(f"Summary file: {summary_file}")
    print(f"\nTip: If tasks are still processing, you can rerun the script later to check status")
387
+
388
if __name__ == "__main__":
    # Script entry point: run the serial batch-generation loop.
    main()
390
+
baseline_generate/suno/__pycache__/suno_4_5.cpython-311.pyc ADDED
Binary file (41.6 kB). View file
 
baseline_generate/suno/__pycache__/suno_5.cpython-311.pyc ADDED
Binary file (41.7 kB). View file
 
baseline_generate/suno/config.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration file example
3
+ Copy this file as config.py and fill in your actual configuration
4
+ """
5
+
6
+ # ============== API Configuration ==============
7
+
8
+ # Suno API key (obtain from https://sunoapi.org)
9
+ SUNO_API_KEY = ""
10
+
11
+ # API base URL (usually no need to modify)
12
+ SUNO_API_BASE_URL = "https://api.sunoapi.org"
13
+
14
+ # ============== Generation Configuration ==============
15
+
16
+ # Default model version
17
+ DEFAULT_MODEL_VERSION = "V5" # Options: V3_5, V4, V4_5, V4_5PLUS, V5
18
+
19
+ # Whether to enable custom mode
20
+ DEFAULT_CUSTOM_MODE = True
21
+
22
+ # Whether to generate instrumental by default
23
+ DEFAULT_INSTRUMENTAL = False
24
+
25
+ # ============== Task Configuration ==============
26
+
27
+ # Maximum wait time (seconds)
28
+ MAX_WAIT_TIME = 300
29
+
30
+ # Check interval (seconds)
31
+ CHECK_INTERVAL = 10
32
+
33
+ # Retry count
34
+ MAX_RETRIES = 3
35
+
36
+ # ============== File Configuration ==============
37
+
38
+ # Music file save directory
39
+ OUTPUT_DIRECTORY = "./generated_music"
40
+
41
+ # Audio format
42
+ AUDIO_FORMAT = "mp3" # Options: mp3, wav
43
+
44
+ # ============== Batch Generation Configuration ==============
45
+
46
+ # Concurrency for batch generation
47
+ BATCH_CONCURRENCY = 5
48
+
49
+ # Batch generation delay (seconds, to avoid rate limiting)
50
+ BATCH_DELAY = 2
51
+
52
+ # ============== Logging Configuration ==============
53
+
54
+ # Log level
55
+ LOG_LEVEL = "INFO" # Options: DEBUG, INFO, WARNING, ERROR
56
+
57
+ # Log file path
58
+ LOG_FILE = "./suno_api.log"
59
+
60
+ # Whether to output to console
61
+ LOG_TO_CONSOLE = True
62
+
63
+ # ============== Webhook Configuration ==============
64
+
65
+ # Webhook callback URL (optional)
66
+ WEBHOOK_URL = None # Example: "https://your-domain.com/webhook"
67
+
68
+ # Webhook secret (for verifying callback requests)
69
+ WEBHOOK_SECRET = None
70
+
baseline_generate/suno/suno_4_5.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Suno API Batch Generation - V4.5 Special Edition
4
+ Supported models: V4_5 (default), V4_5PLUS, V4_5ALL
5
+ """
6
+ import json
7
+ import time
8
+ import requests
9
+ import os
10
+ import logging
11
+ import csv
12
+ from requests.adapters import HTTPAdapter
13
+ from urllib3.util.retry import Retry
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ from collections import deque
16
+ from threading import Lock, Semaphore
17
+ from tqdm import tqdm
18
+ import sys
19
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
20
+ from config import SUNO_API_KEY
21
+
22
+
23
+ # Configure logging
24
def setup_logging(output_dir, level=logging.INFO):
    """Configure the batch-run logger writing to a timestamped file and stdout.

    The logging level was previously hard-coded to INFO; it is now a
    backward-compatible keyword parameter.

    Args:
        output_dir: Directory in which the run log file is created.
        level: Logging level applied to the logger (default: logging.INFO).

    Returns:
        tuple: (logger, log_file) — the configured ``SunoBatchV4_5`` logger
        and the path of the log file it writes to.
    """
    log_file = os.path.join(output_dir, f"run_log_v4_5_{time.strftime('%Y%m%d_%H%M%S')}.txt")

    # Create (or fetch) the module's named logger.
    logger = logging.getLogger('SunoBatchV4_5')
    logger.setLevel(level)

    # Drop handlers from any previous call so records are not duplicated.
    if logger.hasHandlers():
        logger.handlers.clear()

    # File handler: bare message lines, UTF-8 so unicode symbols survive.
    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(file_handler)

    # Console handler mirrors the file output.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(console_handler)

    return logger, log_file
46
+
47
# Module-level logger; reconfigured with handlers by setup_logging() at startup.
logger = logging.getLogger('SunoBatchV4_5')

def print_log(msg):
    """Drop-in replacement for print() that routes through the shared logger."""
    logger.info(msg)
53
+
54
+
55
class SunoAPI:
    """Simplified Suno API client.

    Wraps the sunoapi.org endpoints used by this script: task submission,
    status polling, timestamped-lyrics retrieval and audio download.  All
    traffic goes through a shared ``requests.Session`` configured with
    automatic retries for transient server errors.
    """

    def __init__(self, api_key):
        """Create a client authenticated with *api_key* (Bearer token)."""
        self.api_key = api_key
        self.base_url = 'https://api.sunoapi.org/api/v1'
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

        # Configure retry strategy: up to 5 retries with exponential backoff
        # (1s, 2s, 4s, 8s...) on transient 5xx responses.
        self.session = requests.Session()
        retry_strategy = Retry(
            total=5,  # Maximum retry count
            backoff_factor=1,  # Retry interval (1s, 2s, 4s, 8s...)
            status_forcelist=[500, 502, 503, 504],  # Status codes that need retry
            allowed_methods=["HEAD", "GET", "POST", "OPTIONS"]  # Allowed retry methods
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    def generate_music(self, prompt, model='V4_5', vocalGender=None, **options):
        """Submit a generation request; return the created task id.

        Raises ``Exception`` on HTTP errors, non-JSON responses, or an API
        ``code`` other than 200.
        """
        payload = {
            'prompt': prompt,
            'model': model,
            'callBackUrl': 'https://example.com/callback',
            **options
        }

        if vocalGender:
            payload['vocalGender'] = vocalGender

        try:
            response = self.session.post(
                f'{self.base_url}/generate',
                headers=self.headers,
                json=payload,
                timeout=30
            )

            # Check HTTP errors
            response.raise_for_status()

            # Try to parse JSON
            try:
                result = response.json()
            except json.JSONDecodeError:
                raise Exception(f"API returned non-JSON response: {response.text[:200]}")

            if result.get('code') != 200:
                raise Exception(f"Generation failed: {result.get('msg', result)}")

            return result['data']['taskId']

        except requests.exceptions.RequestException as e:
            raise Exception(f"Request exception: {str(e)}")

    def get_task_status(self, task_id):
        """Return the ``data`` dict of the task record.

        Errors are propagated rather than swallowed: ``wait_for_completion``
        treats exceptions as transient and retries.
        """
        response = self.session.get(
            f'{self.base_url}/generate/record-info?taskId={task_id}',
            headers={'Authorization': f'Bearer {self.api_key}'},
            timeout=30
        )
        response.raise_for_status()
        return response.json().get('data', {})

    def get_timestamped_lyrics(self, task_id, audio_id):
        """Fetch timestamped lyrics for one track; return {} on any failure."""
        payload = {
            'taskId': task_id,
            'audioId': audio_id
        }

        try:
            response = self.session.post(
                f'{self.base_url}/generate/get-timestamped-lyrics',
                headers=self.headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            return response.json()
        except Exception:
            return {}  # Lyrics retrieval failure is non-fatal error

    def wait_for_completion(self, task_id, max_wait_time=600, check_interval=5):
        """Poll until the task finishes; return result and polling statistics.

        Bug fix: a FAILED status now aborts immediately.  Previously the
        "Task failed" exception raised inside the polling ``try`` block was
        caught by its own generic ``except`` and the task kept being polled
        until the overall timeout.
        """
        start_time = time.time()
        poll_count = 0
        total_poll_time = 0

        while time.time() - start_time < max_wait_time:
            try:
                poll_start = time.time()
                status = self.get_task_status(task_id)
                poll_time = time.time() - poll_start
                poll_count += 1
                total_poll_time += poll_time
            except Exception:
                # Transient polling error: keep retrying until the deadline.
                if time.time() - start_time >= max_wait_time:
                    raise
                time.sleep(check_interval)
                continue

            current_status = status.get('status')

            if current_status == 'SUCCESS':
                return {
                    'result': status.get('response'),
                    'wait_time': time.time() - start_time,
                    'poll_count': poll_count,
                    'avg_poll_time': total_poll_time / poll_count if poll_count > 0 else 0
                }
            if current_status == 'FAILED':
                # Permanent failure: raise outside the try so it is not retried.
                raise Exception(f"Task failed: {status.get('errorMessage')}")

            time.sleep(check_interval)

        raise Exception('Task timeout')

    def download_file(self, url, save_path):
        """Stream *url* to *save_path*; return download statistics.

        Never raises: failures are reported as ``{'success': False, ...}``.
        """
        try:
            start_time = time.time()
            downloaded_bytes = 0

            # Use session to download
            with self.session.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(save_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive empty chunks
                            f.write(chunk)
                            downloaded_bytes += len(chunk)

            download_time = time.time() - start_time
            return {
                'success': True,
                'bytes': downloaded_bytes,
                'time': download_time,
                'speed': downloaded_bytes / download_time if download_time > 0 else 0
            }
        except Exception as e:
            print_log(f"Download failed {url}: {e}")
            return {'success': False, 'error': str(e)}
207
+
208
+
209
# Lock serialising appends to the shared results CSV across worker threads.
result_lock = Lock()

def save_result_record(output_dir, record):
    """Append one generation result to generation_results.csv (thread-safe).

    Only the key fields are persisted; missing fields default sensibly.
    Fix: the file-existence check now happens *inside* the lock, so two
    threads can no longer both decide to write the header row.
    """
    file_path = os.path.join(output_dir, "generation_results.csv")

    # Only record key information
    row = {
        'song_id': record.get('song_id'),
        'task_id': record.get('task_id'),
        'status': 'SUCCESS' if record.get('success') else 'FAILED',
        'error': record.get('error', ''),
        'submit_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(record.get('submit_time', 0))),
        'total_time': f"{record.get('total_time', 0):.1f}",
        'tracks_count': record.get('tracks_count', 0)
    }

    with result_lock:
        # Check inside the lock to avoid a duplicated-header race.
        file_exists = os.path.isfile(file_path)
        with open(file_path, 'a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=['song_id', 'task_id', 'status', 'error', 'submit_time', 'total_time', 'tracks_count'])
            if not file_exists:
                writer.writeheader()
            writer.writerow(row)
234
+
235
+
236
class ImprovedRateLimiter:
    """Sliding-window rate limiter with usage statistics.

    Guarantees at most ``max_requests`` calls to :meth:`acquire` in any
    ``time_window``-second window (default: 5 requests per 10 seconds).

    Fixes vs. the previous version:
    * The docstring claimed "8 requests per 10 seconds" while the defaults
      enforce 5 — the text now matches the code.
    * ``wait_count`` used to be incremented on every acquire, even when no
      wait happened, which made ``avg_wait_time`` meaningless; it now
      counts only acquires that actually slept.
    """

    def __init__(self, max_requests=5, time_window=10):
        self.max_requests = max_requests
        self.time_window = time_window
        self.request_times = deque()  # timestamps of requests inside the window
        self.lock = Lock()
        self.semaphore = Semaphore(max_requests)  # kept for interface compatibility (unused)

        # Statistics
        self.total_wait_time = 0
        self.wait_count = 0       # number of acquires that actually had to sleep
        self.total_requests = 0

    def acquire(self):
        """Block until a request slot is available, then consume it."""
        with self.lock:
            now = time.time()

            # Clean expired request records
            while self.request_times and now - self.request_times[0] >= self.time_window:
                self.request_times.popleft()

            # If limit reached, calculate wait time needed
            wait_time = 0
            if len(self.request_times) >= self.max_requests:
                oldest_request = self.request_times[0]
                wait_time = self.time_window - (now - oldest_request) + 0.05  # Add buffer

            if wait_time > 0:
                print_log(f" [Rate Limit] Waiting {wait_time:.2f} seconds...")
                time.sleep(wait_time)

                # Count only real waits so avg_wait_time reflects actual delays.
                self.total_wait_time += wait_time
                self.wait_count += 1

                # Re-clean after sleeping (timestamps may have aged out).
                now = time.time()
                while self.request_times and now - self.request_times[0] >= self.time_window:
                    self.request_times.popleft()

            # Record this request time
            self.request_times.append(time.time())
            self.total_requests += 1

    def get_current_rate(self):
        """Return the number of requests recorded in the current window."""
        with self.lock:
            now = time.time()
            while self.request_times and now - self.request_times[0] >= self.time_window:
                self.request_times.popleft()
            return len(self.request_times)

    def get_stats(self):
        """Return a snapshot of usage statistics."""
        with self.lock:
            return {
                'total_requests': self.total_requests,
                'total_wait_time': self.total_wait_time,
                'wait_count': self.wait_count,
                'avg_wait_time': self.total_wait_time / self.wait_count if self.wait_count > 0 else 0
            }
304
+
305
+
306
# Global rate limiter (5 requests per 10 seconds); main() re-applies these
# same values at startup before submission begins.
rate_limiter = ImprovedRateLimiter(time_window=10, max_requests=5)
308
+
309
+
310
def submit_generation_task(api, song_index, data):
    """Phase 1: Submit generation task (rate limited).

    Returns a dict describing the submission; on failure the dict carries
    ``success: False`` plus the error text so phase 2 can skip it.
    (Fix: removed the unused ``request_start`` timestamp — ``submit_start``
    is the one actually used for the request-time statistic.)
    """
    # Use sunov4_5_000001 format
    song_id = data.get("id", f"sunov4_5_{song_index:06d}")

    try:
        description = data.get("description", "")
        lyrics = data.get("lyrics", "")
        vocal_gender = data.get("vocalGender")

        print_log(f"[Song {song_id}] Submitting task... (current rate: {rate_limiter.get_current_rate()}/5)")

        # Rate limiting (blocks until a request slot is free)
        rate_limiter.acquire()

        # Submit task
        submit_start = time.time()
        task_id = api.generate_music(
            prompt=lyrics,
            style=description,
            title=f"Song_{song_id}",
            model='V4_5',  # Explicitly specify V4.5 model
            customMode=True,
            instrumental=False,
            vocalGender=vocal_gender
        )
        request_time = time.time() - submit_start

        print_log(f"[Song {song_id}] ✓ Task submitted, ID: {task_id}")

        return {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'data': data,
            'submit_time': time.time(),
            'request_time': request_time,
            'success': True
        }

    except Exception as e:
        print_log(f"[Song {song_id}] ✗ Submission failed: {e}")
        # If submission fails, also record it (even though not at download stage yet)
        return {
            'song_id': song_id,
            'song_index': song_index,
            'success': False,
            'error': str(e)
        }
362
+
363
+
364
def wait_and_download_result(api, task_info, output_dir):
    """Phase 2: Wait for result and download (not rate limited)."""
    # Failed submissions pass straight through untouched.
    if not task_info['success']:
        return task_info

    song_id = task_info['song_id']
    song_index = task_info['song_index']
    task_id = task_info['task_id']
    data = task_info['data']
    submitted_at = task_info['submit_time']

    try:
        original_lyrics = data.get("original_lyrics", data.get("lyrics", ""))
        lyrics = data.get("lyrics", "")
        description = data.get("description", "")

        print_log(f"[Song {song_id}] Waiting for generation to complete...")

        # Block until the task finishes; returns detailed polling statistics.
        wait_result = api.wait_for_completion(task_id, max_wait_time=600, check_interval=8)
        payload = wait_result['result']

        # Locate the list of generated tracks inside the response payload:
        # known keys first, then a scan for any list of dicts with audioUrl.
        tracks = []
        if isinstance(payload, dict):
            if 'data' in payload:
                tracks = payload['data']
            elif 'sunoData' in payload:
                tracks = payload['sunoData']
            else:
                for value in payload.values():
                    if isinstance(value, list) and len(value) > 0 and 'audioUrl' in value[0]:
                        tracks = value
                        break

        if not tracks:
            raise Exception("Audio track data not found")

        # Download-phase bookkeeping.
        download_start = time.time()
        saved_files = []
        bytes_downloaded = 0
        files_downloaded = 0

        for idx, track in enumerate(tracks):
            audio_url = track.get('audioUrl') or track.get('audio_url')
            audio_id = track.get('id')

            stem = f"{song_id}_{idx}"
            audio_path = os.path.join(output_dir, f"{stem}.mp3")
            lyrics_path = os.path.join(output_dir, f"{stem}_lyrics.json")

            # Fetch the audio file if a URL was provided.
            if audio_url:
                dl = api.download_file(audio_url, audio_path)
                if dl['success']:
                    saved_files.append(audio_path)
                    bytes_downloaded += dl['bytes']
                    files_downloaded += 1

            # Timestamped lyrics are best-effort.
            timestamped = None
            if audio_id:
                try:
                    lyr = api.get_timestamped_lyrics(task_id, audio_id)
                    if lyr.get('code') == 200:
                        timestamped = lyr.get('data')
                except Exception as e:
                    print_log(f"[Song {song_id}] Track {idx+1}: Failed to get lyrics: {e}")

            # Persist lyrics + metadata alongside the audio.
            metadata = {
                "song_id": song_id,
                "song_index": song_index,
                "track_index": idx,
                "original_lyrics": original_lyrics,
                "cleaned_lyrics": lyrics,
                "timestamped_lyrics": timestamped,
                "style": description,
                "full_track_data": track
            }
            with open(lyrics_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, ensure_ascii=False, indent=2)
            saved_files.append(lyrics_path)

        download_time = time.time() - download_start
        total_time = time.time() - submitted_at

        print_log(f"[Song {song_id}] ✓ Complete! {len(tracks)} tracks, took {total_time:.1f} seconds")

        final_result = {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'success': True,
            'tracks_count': len(tracks),
            'files': saved_files,
            'total_time': total_time,
            'submit_time': submitted_at,
            'wait_time': wait_result['wait_time'],
            'poll_count': wait_result['poll_count'],
            'avg_poll_time': wait_result['avg_poll_time'],
            'download_time': download_time,
            'download_bytes': bytes_downloaded,
            'download_count': files_downloaded,
            'avg_download_speed': bytes_downloaded / download_time if download_time > 0 else 0
        }

        # Save result in real-time
        save_result_record(output_dir, final_result)
        return final_result

    except Exception as e:
        total_time = time.time() - submitted_at
        print_log(f"[Song {song_id}] ✗ Processing failed: {e} (took {total_time:.1f} seconds)")

        error_result = {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'success': False,
            'error': str(e),
            'total_time': total_time,
            'submit_time': submitted_at
        }

        # Save result in real-time
        save_result_record(output_dir, error_result)
        return error_result
495
+
496
+
497
def format_bytes(bytes_size):
    """Render a byte count as a human-readable string (B/KB/MB/GB/TB)."""
    size = bytes_size
    for unit in ('B', 'KB', 'MB', 'GB'):
        if size < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    return f"{size:.2f} TB"
504
+
505
+
506
def format_speed(bytes_per_sec):
    """Render a transfer rate as a human-readable '<size>/s' string."""
    return format_bytes(bytes_per_sec) + "/s"
509
+
510
+
511
def _load_input_data(input_file):
    """Load the song list from *input_file* (JSON array or JSONL).

    Handles three layouts: a plain JSON file, a .jsonl file that actually
    contains a JSON array, and a true JSONL file (one object per line).
    Returns the list of song dicts, or None (after logging the reason)
    when the file is missing or unparseable.
    """
    try:
        all_data = []
        if input_file.endswith('.jsonl'):
            try:
                with open(input_file, 'r', encoding='utf-8') as f:
                    # Try reading first line to determine format
                    first_line = f.readline().strip()
                    if first_line.startswith('['):
                        # Looks like regular JSON array
                        f.seek(0)
                        all_data = json.load(f)
                    else:
                        # Try reading line by line
                        f.seek(0)
                        for line in f:
                            line = line.strip()
                            if line:
                                all_data.append(json.loads(line))
            except json.JSONDecodeError:
                # If above parsing fails, try one final read as regular JSON
                print_log(f"Note: Failed to parse {input_file} as JSONL format, trying as regular JSON...")
                with open(input_file, 'r', encoding='utf-8') as f:
                    all_data = json.load(f)
        else:
            with open(input_file, 'r', encoding='utf-8') as f:
                all_data = json.load(f)
        return all_data
    except FileNotFoundError:
        print_log(f"File {input_file} not found.")
        return None
    except json.JSONDecodeError as e:
        print_log(f"JSON parsing error: {e}")
        return None


def main():
    """Main program - two-phase concurrent processing.

    Phase 1 submits all generation tasks under the global rate limiter;
    phase 2 polls and downloads results with a wider thread pool; then a
    detailed performance report is printed.
    (Changes: input loading extracted into _load_input_data(); unused
    enumerate index removed from the JSONL loop.)
    """
    input_file = "cleaned_data_truncated.json"
    output_dir = "sunov4_5_truncated"
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Initialize logging
    global logger
    logger, log_file = setup_logging(output_dir)

    print_log("=" * 70)
    print_log("Suno API Batch Generation - V4.5 Special Edition")
    print_log("Strategy: Fast submission (5 requests/10s) + Parallel waiting + Detailed performance analysis")
    print_log(f"Log file: {log_file}")
    print_log("=" * 70)

    # Read input file
    all_data = _load_input_data(input_file)
    if all_data is None:
        return

    # Initialize API
    api = SunoAPI(SUNO_API_KEY)

    print_log(f"\nPreparing to generate {len(all_data)} songs...")
    print_log(f"Start time: {time.strftime('%H:%M:%S')}\n")

    overall_start_time = time.time()

    # ===== Phase 1: Batch Submission =====
    print_log("\n" + "=" * 70)
    print_log("Phase 1: Batch Submission")
    print_log("=" * 70 + "\n")

    submit_start_time = time.time()
    submitted_tasks = []
    total_request_time = 0

    # Adjust rate limit: maximum 5 requests per 10 seconds
    rate_limiter.max_requests = 5
    rate_limiter.time_window = 10
    rate_limiter.request_times.clear()
    print_log(f"Rate limit: {rate_limiter.max_requests} requests / {rate_limiter.time_window} seconds")

    # Only submit tasks that need to run (1-based song indices)
    tasks_to_run = list(enumerate(all_data, 1))

    print_log(f"Number of tasks to submit: {len(tasks_to_run)}")

    # Use thread pool for submission
    # Submission concurrency is controlled by rate_limiter, can be set to 5
    with ThreadPoolExecutor(max_workers=5) as executor:
        submit_futures = {
            executor.submit(submit_generation_task, api, idx, data): idx
            for idx, data in tasks_to_run
        }

        with tqdm(total=len(tasks_to_run), desc="Submitting tasks", unit="song") as pbar:
            for future in as_completed(submit_futures):
                result = future.result()
                submitted_tasks.append(result)
                if result.get('success') and 'request_time' in result:
                    total_request_time += result['request_time']
                pbar.update(1)

    submit_phase_time = time.time() - submit_start_time
    success_submits = sum(1 for t in submitted_tasks if t['success'])

    # Get rate limit statistics
    rate_limit_stats = rate_limiter.get_stats()

    print_log(f"\nSubmission phase complete: {success_submits}/{len(tasks_to_run)} successful")
    print_log(f" Total time: {submit_phase_time:.1f} seconds")
    print_log(f" Actual request time: {total_request_time:.2f} seconds")
    print_log(f" Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds ({rate_limit_stats['wait_count']} times)")
    if rate_limit_stats['wait_count'] > 0:
        print_log(f" Average wait time: {rate_limit_stats['avg_wait_time']:.2f} seconds/time")

    # ===== Phase 2: Parallel Waiting and Download =====
    print_log("\n" + "=" * 70)
    print_log("Phase 2: Wait for Generation and Download")
    print_log("=" * 70 + "\n")

    wait_start_time = time.time()
    final_results = []

    # Use more threads for parallel waiting (not rate limited)
    with ThreadPoolExecutor(max_workers=20) as executor:
        download_futures = {
            executor.submit(wait_and_download_result, api, task, output_dir): task
            for task in submitted_tasks if task['success']
        }

        # Add failed submission tasks to results
        for task in submitted_tasks:
            if not task['success']:
                final_results.append(task)

        with tqdm(total=len(download_futures), desc="Downloading results", unit="song") as pbar:
            for future in as_completed(download_futures):
                result = future.result()
                final_results.append(result)
                pbar.update(1)

    wait_phase_time = time.time() - wait_start_time

    # ===== Detailed Statistics and Report =====
    overall_time = time.time() - overall_start_time

    print_log("\n" + "=" * 70)
    print_log("Batch Generation Complete - Detailed Performance Report")
    print_log("=" * 70)

    success_count = sum(1 for r in final_results if r.get('success'))
    fail_count = len(final_results) - success_count
    total_tracks = sum(r.get('tracks_count', 0) for r in final_results if r.get('success'))

    successful_results = [r for r in final_results if r.get('success')]

    # Basic Statistics
    print_log(f"\n[Basic Statistics]")
    print_log(f" Total songs: {len(all_data)}")
    print_log(f" Successful: {success_count}")
    print_log(f" Failed: {fail_count}")
    print_log(f" Total tracks: {total_tracks}")
    if success_count > 0:
        avg_tracks = total_tracks / success_count
        print_log(f" Average tracks per song: {avg_tracks:.2f}")

    # Time Statistics
    print_log(f"\n[Time Statistics]")
    print_log(f" ├── Submission phase: {submit_phase_time:.1f} seconds")
    print_log(f" │ ├── Actual request time: {total_request_time:.2f} seconds")
    print_log(f" │ └── Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds")
    print_log(f" ├── Generation waiting phase: {wait_phase_time:.1f} seconds")

    if successful_results:
        wait_times = [r.get('wait_time', 0) for r in successful_results if 'wait_time' in r]
        download_times = [r.get('download_time', 0) for r in successful_results if 'download_time' in r]

        if wait_times:
            avg_wait = sum(wait_times) / len(wait_times)
            min_wait = min(wait_times)
            max_wait = max(wait_times)
            print_log(f" │ ├── Average wait time: {avg_wait:.1f} seconds/song")
            print_log(f" │ ├── Fastest: {min_wait:.1f} seconds")
            print_log(f" │ └── Slowest: {max_wait:.1f} seconds")

        if download_times:
            total_download_time = sum(download_times)
            avg_download = total_download_time / len(download_times)
            print_log(f" ├── Download phase: {total_download_time:.1f} seconds")
            print_log(f" │ └── Average download time: {avg_download:.2f} seconds/song")

    print_log(f" └── Total time: {overall_time:.1f} seconds ({overall_time/60:.1f} minutes)")

    # Single Song Generation Statistics
    if successful_results:
        total_times = [r.get('total_time', 0) for r in successful_results if 'total_time' in r]
        if total_times:
            print_log(f"\n[Single Song Generation Statistics]")
            avg_time = sum(total_times) / len(total_times)
            min_time = min(total_times)
            max_time = max(total_times)
            print_log(f" Average total time per song: {avg_time:.1f} seconds")
            print_log(f" Fastest generation: {min_time:.1f} seconds")
            print_log(f" Slowest generation: {max_time:.1f} seconds")

        # Download Statistics
        total_download_bytes = sum(r.get('download_bytes', 0) for r in successful_results)
        total_download_count = sum(r.get('download_count', 0) for r in successful_results)

        if total_download_bytes > 0:
            print_log(f"\n[Download Statistics]")
            print_log(f" Total download: {format_bytes(total_download_bytes)}")
            print_log(f" Number of files: {total_download_count}")
            print_log(f" Average file size: {format_bytes(total_download_bytes / total_download_count)}")

            download_speeds = [r.get('avg_download_speed', 0) for r in successful_results if r.get('avg_download_speed', 0) > 0]
            if download_speeds:
                avg_speed = sum(download_speeds) / len(download_speeds)
                print_log(f" Average download speed: {format_speed(avg_speed)}")

        # Polling Statistics
        poll_counts = [r.get('poll_count', 0) for r in successful_results if 'poll_count' in r]
        if poll_counts:
            total_polls = sum(poll_counts)
            avg_polls = total_polls / len(poll_counts)
            print_log(f"\n[Polling Statistics]")
            print_log(f" Total polling count: {total_polls}")
            print_log(f" Average polling per song: {avg_polls:.1f}")

    # Efficiency Analysis
    print_log(f"\n[Efficiency Analysis]")
    if success_count > 0:
        throughput = success_count / (overall_time / 60)
        print_log(f" Actual throughput: {throughput:.2f} songs/minute")

        # Theoretical fastest time (assuming no rate limit)
        if wait_times:
            ideal_time = submit_phase_time - rate_limit_stats['total_wait_time'] + max(wait_times)
            efficiency = (ideal_time / overall_time) * 100
            print_log(f" Theoretical fastest time: {ideal_time:.1f} seconds")
            print_log(f" Concurrency efficiency: {efficiency:.1f}%")

    # Show failed songs
    if fail_count > 0:
        print_log("\n" + "=" * 70)
        print_log("Failed Songs List")
        print_log("=" * 70)
        for r in sorted(final_results, key=lambda x: x.get('song_index', 0)):
            if not r.get('success'):
                song_id = r.get('song_id', r.get('song_index', 'Unknown'))
                print_log(f" [{song_id}] {r.get('error', 'Unknown error')}")

    print_log("\n" + "=" * 70)
    print_log(f"All files saved to: {os.path.abspath(output_dir)}")
    print_log("=" * 70)
762
+
763
+
764
# Script entry point: run the two-phase batch pipeline when executed directly.
if __name__ == '__main__':
    main()
766
+
baseline_generate/suno/suno_5.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Suno API Batch Generation - V5 Version (5 requests per 10 seconds)
4
+ Changes:
5
+ 1. Rate control: 5 requests within 10 seconds
6
+ 2. ID format: sunov5_000001
7
+ """
8
+ import json
9
+ import time
10
+ import requests
11
+ import os
12
+ import logging
13
+ import csv
14
+ from requests.adapters import HTTPAdapter
15
+ from urllib3.util.retry import Retry
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
17
+ from collections import deque
18
+ from threading import Lock, Semaphore
19
+ from tqdm import tqdm
20
+ import sys
21
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22
+ from config import SUNO_API_KEY
23
+
24
+
25
# Configure logging
def setup_logging(output_dir):
    """Create the run logger writing to a timestamped file and the console.

    Returns ``(logger, log_file_path)``.
    Fix: old handlers were previously discarded without being closed, which
    leaked the open log-file descriptor when called more than once; they
    are now closed before removal.
    """
    log_file = os.path.join(output_dir, f"run_log_{time.strftime('%Y%m%d_%H%M%S')}.txt")

    # Create logger
    logger = logging.getLogger('SunoBatchV5')
    logger.setLevel(logging.INFO)

    # Close and remove old handlers so repeated calls don't leak file
    # handles or duplicate output.
    for handler in list(logger.handlers):
        handler.close()
        logger.removeHandler(handler)

    # File Handler: raw messages, UTF-8
    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(file_handler)

    # Console Handler mirrors the file output
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(console_handler)

    return logger, log_file
48
+
49
# Module-level logger shared by the whole script; handlers are attached
# later by setup_logging() in main().
logger = logging.getLogger('SunoBatchV5')


def print_log(msg):
    """Drop-in replacement for print() that routes through the logger."""
    logger.info(msg)
55
+
56
+
57
class SunoAPI:
    """Simplified Suno API client (V5 variant).

    Wraps the sunoapi.org endpoints used by this script; all traffic goes
    through a shared ``requests.Session`` with automatic retries for
    transient server errors.
    """

    def __init__(self, api_key):
        """Create a client authenticated with *api_key* (Bearer token)."""
        self.api_key = api_key
        self.base_url = 'https://api.sunoapi.org/api/v1'
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

        # Configure retry strategy: up to 5 retries with exponential backoff
        # (1s, 2s, 4s, 8s...) on transient 5xx responses.
        self.session = requests.Session()
        retry_strategy = Retry(
            total=5,  # Maximum retry count
            backoff_factor=1,  # Retry interval (1s, 2s, 4s, 8s...)
            status_forcelist=[500, 502, 503, 504],  # Status codes that need retry
            allowed_methods=["HEAD", "GET", "POST", "OPTIONS"]  # Allowed retry methods
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    def generate_music(self, prompt, model='V5', vocalGender=None, **options):
        """Submit a generation request; return the created task id.

        Raises ``Exception`` on HTTP errors, non-JSON responses, or an API
        ``code`` other than 200.
        """
        payload = {
            'prompt': prompt,
            'model': model,
            'callBackUrl': 'https://example.com/callback',
            **options
        }

        if vocalGender:
            payload['vocalGender'] = vocalGender

        try:
            response = self.session.post(
                f'{self.base_url}/generate',
                headers=self.headers,
                json=payload,
                timeout=30
            )

            # Check HTTP errors
            response.raise_for_status()

            # Try to parse JSON
            try:
                result = response.json()
            except json.JSONDecodeError:
                raise Exception(f"API returned non-JSON response: {response.text[:200]}")

            if result.get('code') != 200:
                raise Exception(f"Generation failed: {result.get('msg', result)}")

            return result['data']['taskId']

        except requests.exceptions.RequestException as e:
            raise Exception(f"Request exception: {str(e)}")

    def get_task_status(self, task_id):
        """Return the ``data`` dict of the task record.

        Errors are propagated rather than swallowed: ``wait_for_completion``
        treats exceptions as transient and retries.
        """
        response = self.session.get(
            f'{self.base_url}/generate/record-info?taskId={task_id}',
            headers={'Authorization': f'Bearer {self.api_key}'},
            timeout=30
        )
        response.raise_for_status()
        return response.json().get('data', {})

    def get_timestamped_lyrics(self, task_id, audio_id):
        """Fetch timestamped lyrics for one track; return {} on any failure."""
        payload = {
            'taskId': task_id,
            'audioId': audio_id
        }

        try:
            response = self.session.post(
                f'{self.base_url}/generate/get-timestamped-lyrics',
                headers=self.headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            return response.json()
        except Exception:
            return {}  # Lyrics retrieval failure is non-fatal error

    def wait_for_completion(self, task_id, max_wait_time=600, check_interval=5):
        """Poll until the task finishes; return result and polling statistics.

        Bug fix: a FAILED status now aborts immediately.  Previously the
        "Task failed" exception raised inside the polling ``try`` block was
        caught by its own generic ``except`` and the task kept being polled
        until the overall timeout.
        """
        start_time = time.time()
        poll_count = 0
        total_poll_time = 0

        while time.time() - start_time < max_wait_time:
            try:
                poll_start = time.time()
                status = self.get_task_status(task_id)
                poll_time = time.time() - poll_start
                poll_count += 1
                total_poll_time += poll_time
            except Exception:
                # Transient polling error: keep retrying until the deadline.
                if time.time() - start_time >= max_wait_time:
                    raise
                time.sleep(check_interval)
                continue

            current_status = status.get('status')

            if current_status == 'SUCCESS':
                return {
                    'result': status.get('response'),
                    'wait_time': time.time() - start_time,
                    'poll_count': poll_count,
                    'avg_poll_time': total_poll_time / poll_count if poll_count > 0 else 0
                }
            if current_status == 'FAILED':
                # Permanent failure: raise outside the try so it is not retried.
                raise Exception(f"Task failed: {status.get('errorMessage')}")

            time.sleep(check_interval)

        raise Exception('Task timeout')

    def download_file(self, url, save_path):
        """Stream *url* to *save_path*; return download statistics.

        Never raises: failures are reported as ``{'success': False, ...}``.
        """
        try:
            start_time = time.time()
            downloaded_bytes = 0

            # Use session to download
            with self.session.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(save_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive empty chunks
                            f.write(chunk)
                            downloaded_bytes += len(chunk)

            download_time = time.time() - start_time
            return {
                'success': True,
                'bytes': downloaded_bytes,
                'time': download_time,
                'speed': downloaded_bytes / download_time if download_time > 0 else 0
            }
        except Exception as e:
            print_log(f"Download failed {url}: {e}")
            return {'success': False, 'error': str(e)}
209
+
210
+
211
# Lock serialising appends to the shared results CSV across worker threads.
result_lock = Lock()

def save_result_record(output_dir, record):
    """Append one generation result to generation_results.csv (thread-safe).

    Only the key fields are persisted; missing fields default sensibly.
    Fix: the file-existence check now happens *inside* the lock, so two
    threads can no longer both decide to write the header row.
    """
    file_path = os.path.join(output_dir, "generation_results.csv")

    # Only record key information
    row = {
        'song_id': record.get('song_id'),
        'task_id': record.get('task_id'),
        'status': 'SUCCESS' if record.get('success') else 'FAILED',
        'error': record.get('error', ''),
        'submit_time': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(record.get('submit_time', 0))),
        'total_time': f"{record.get('total_time', 0):.1f}",
        'tracks_count': record.get('tracks_count', 0)
    }

    with result_lock:
        # Check inside the lock to avoid a duplicated-header race.
        file_exists = os.path.isfile(file_path)
        with open(file_path, 'a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=['song_id', 'task_id', 'status', 'error', 'submit_time', 'total_time', 'tracks_count'])
            if not file_exists:
                writer.writeheader()
            writer.writerow(row)
236
+
237
+
238
class ImprovedRateLimiter:
    """Improved rate limiter (with statistics).

    Sliding-window algorithm: guarantees no more than ``max_requests``
    acquisitions complete within any ``time_window``-second window.
    Defaults to 5 requests per 10 seconds. (The original docstring said
    "8 requests per 10 seconds", contradicting the actual default of 5.)

    NOTE: acquire() sleeps while holding the internal lock, so concurrent
    callers are serialized during a wait — this is what enforces the limit
    across threads.
    """

    def __init__(self, max_requests=5, time_window=10):
        self.max_requests = max_requests      # allowed requests per window
        self.time_window = time_window        # window length in seconds
        self.request_times = deque()          # timestamps of recent acquisitions
        self.lock = Lock()
        # Retained for backward compatibility with any external users of the
        # attribute; not used by the limiter itself.
        self.semaphore = Semaphore(max_requests)

        # Statistics
        self.total_wait_time = 0
        self.wait_count = 0        # acquisitions that actually had to sleep
        self.total_requests = 0

    def acquire(self):
        """Block until a request is permitted, then record it."""
        with self.lock:
            now = time.time()

            # Clean expired request records
            while self.request_times and now - self.request_times[0] >= self.time_window:
                self.request_times.popleft()

            # If limit reached, calculate wait time needed
            wait_time = 0
            if len(self.request_times) >= self.max_requests:
                oldest_request = self.request_times[0]
                wait_time = self.time_window - (now - oldest_request) + 0.05  # Add buffer

            if wait_time > 0:
                print_log(f" [Rate Limit] Waiting {wait_time:.2f} seconds...")
                time.sleep(wait_time)

                # Record wait statistics. Fix: only acquisitions that actually
                # slept are counted, so avg_wait_time reflects real waits
                # instead of being diluted by zero-wait acquisitions.
                self.total_wait_time += wait_time
                self.wait_count += 1

                # Re-clean: entries may have expired while we slept
                now = time.time()
                while self.request_times and now - self.request_times[0] >= self.time_window:
                    self.request_times.popleft()

            # Record this request time
            self.request_times.append(time.time())
            self.total_requests += 1

    def get_current_rate(self):
        """Get current rate (number of requests within the last window)."""
        with self.lock:
            now = time.time()
            while self.request_times and now - self.request_times[0] >= self.time_window:
                self.request_times.popleft()
            return len(self.request_times)

    def get_stats(self):
        """Get statistics: totals plus average wait per *waiting* acquisition."""
        with self.lock:
            return {
                'total_requests': self.total_requests,
                'total_wait_time': self.total_wait_time,
                'wait_count': self.wait_count,
                'avg_wait_time': self.total_wait_time / self.wait_count if self.wait_count > 0 else 0
            }
306
+
307
+
308
# Global rate limiter (5 requests per 10 seconds)
# NOTE: main() mutates max_requests/time_window and clears request_times at
# startup, so this is shared mutable state across the whole run.
rate_limiter = ImprovedRateLimiter(max_requests=5, time_window=10)
310
+
311
+
312
def submit_generation_task(api, song_index, data):
    """Phase 1: Submit generation task (rate limited).

    Args:
        api: SunoAPI client used to submit the request.
        song_index: 1-based position of the song in the input list; used to
            build a fallback id when ``data`` carries no "id".
        data: song spec dict; keys read: "id", "description", "lyrics",
            "vocalGender".

    Returns:
        On success, a dict with 'task_id' plus timing info ('submit_time',
        'request_time') consumed by wait_and_download_result; on failure,
        a dict with 'success': False and the error text. Never raises.
    """
    # Use sunov5_000001 format
    song_id = data.get("id", f"sunov5_{song_index:06d}")

    try:
        description = data.get("description", "")
        lyrics = data.get("lyrics", "")
        vocal_gender = data.get("vocalGender")  # may be None; passed through as-is

        print_log(f"[Song {song_id}] Submitting task... (current rate: {rate_limiter.get_current_rate()}/5)")

        # Rate limiting (blocks until a slot in the sliding window is free)
        rate_limiter.acquire()

        # Submit task; time only the API call itself, after any rate-limit wait.
        # (Fix: removed the unused request_start variable.)
        submit_start = time.time()
        task_id = api.generate_music(
            prompt=lyrics,
            style=description,
            title=f"Song_{song_id}",
            model='V5',
            customMode=True,
            instrumental=False,
            vocalGender=vocal_gender
        )
        request_time = time.time() - submit_start

        print_log(f"[Song {song_id}] ✓ Task submitted, ID: {task_id}")

        return {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'data': data,
            'submit_time': time.time(),
            'request_time': request_time,
            'success': True
        }

    except Exception as e:
        print_log(f"[Song {song_id}] ✗ Submission failed: {e}")
        # If submission fails, also record it (even though not at download stage yet)
        return {
            'song_id': song_id,
            'song_index': song_index,
            'success': False,
            'error': str(e)
        }
364
+
365
+
366
def wait_and_download_result(api, task_info, output_dir):
    """Phase 2: Wait for result and download (not rate limited).

    Takes a task dict produced by submit_generation_task; if the submission
    already failed, it is returned unchanged. Otherwise polls the API until
    the task finishes, downloads every returned audio track plus its lyrics
    metadata into ``output_dir``, appends a row to the results CSV via
    save_result_record, and returns a result dict (with 'success': False and
    the error text on any failure — this function never raises).
    """
    if not task_info['success']:
        return task_info

    song_id = task_info['song_id']
    song_index = task_info['song_index']
    task_id = task_info['task_id']
    data = task_info['data']
    start_time = task_info['submit_time']

    try:
        original_lyrics = data.get("original_lyrics", data.get("lyrics", ""))
        lyrics = data.get("lyrics", "")
        description = data.get("description", "")

        print_log(f"[Song {song_id}] Waiting for generation to complete...")

        # Wait for completion (returns detailed statistics)
        wait_result = api.wait_for_completion(task_id, max_wait_time=600, check_interval=8)
        result = wait_result['result']

        # Process returned result.
        # NOTE(review): the API response shape appears to vary — tracks may sit
        # under 'data', 'sunoData', or some other key whose list items carry
        # 'audioUrl'; all three forms are probed here. Confirm against the
        # current Suno API response schema.
        tracks = []
        if isinstance(result, dict):
            if 'data' in result:
                tracks = result['data']
            elif 'sunoData' in result:
                tracks = result['sunoData']
            else:
                for key, value in result.items():
                    if isinstance(value, list) and len(value) > 0 and 'audioUrl' in value[0]:
                        tracks = value
                        break

        if not tracks:
            raise Exception("Audio track data not found")

        # Download phase statistics
        download_start = time.time()
        downloaded_files = []
        total_download_bytes = 0
        download_count = 0

        # Process each track
        for track_idx, track in enumerate(tracks):
            audio_url = track.get('audioUrl') or track.get('audio_url')
            audio_id = track.get('id')

            base_filename = f"{song_id}_{track_idx}"
            audio_path = os.path.join(output_dir, f"{base_filename}.mp3")
            lyrics_path = os.path.join(output_dir, f"{base_filename}_lyrics.json")

            # Download audio (download_file reports failure via its return
            # value, so a failed audio download does not abort the track)
            if audio_url:
                download_result = api.download_file(audio_url, audio_path)
                if download_result['success']:
                    downloaded_files.append(audio_path)
                    total_download_bytes += download_result['bytes']
                    download_count += 1

            # Get timestamped lyrics (best-effort: failures are logged only)
            timestamped_lyrics_data = None
            if audio_id:
                try:
                    lyrics_response = api.get_timestamped_lyrics(task_id, audio_id)
                    if lyrics_response.get('code') == 200:
                        timestamped_lyrics_data = lyrics_response.get('data')
                except Exception as e:
                    print_log(f"[Song {song_id}] Track {track_idx+1}: Failed to get lyrics: {e}")

            # Save lyrics and metadata (written even when the audio download
            # failed, so the metadata JSON always exists per track)
            lyrics_content = {
                "song_id": song_id,
                "song_index": song_index,
                "track_index": track_idx,
                "original_lyrics": original_lyrics,
                "cleaned_lyrics": lyrics,
                "timestamped_lyrics": timestamped_lyrics_data,
                "style": description,
                "full_track_data": track
            }

            with open(lyrics_path, 'w', encoding='utf-8') as f:
                json.dump(lyrics_content, f, ensure_ascii=False, indent=2)
            downloaded_files.append(lyrics_path)

        download_time = time.time() - download_start
        total_time = time.time() - start_time

        print_log(f"[Song {song_id}] ✓ Complete! {len(tracks)} tracks, took {total_time:.1f} seconds")

        final_result = {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'success': True,
            'tracks_count': len(tracks),
            'files': downloaded_files,
            'total_time': total_time,
            'submit_time': start_time,
            'wait_time': wait_result['wait_time'],
            'poll_count': wait_result['poll_count'],
            'avg_poll_time': wait_result['avg_poll_time'],
            'download_time': download_time,
            'download_bytes': total_download_bytes,
            'download_count': download_count,
            'avg_download_speed': total_download_bytes / download_time if download_time > 0 else 0
        }

        # Save result in real-time
        save_result_record(output_dir, final_result)
        return final_result

    except Exception as e:
        total_time = time.time() - start_time
        print_log(f"[Song {song_id}] ✗ Processing failed: {e} (took {total_time:.1f} seconds)")

        error_result = {
            'song_id': song_id,
            'song_index': song_index,
            'task_id': task_id,
            'success': False,
            'error': str(e),
            'total_time': total_time,
            'submit_time': start_time
        }

        # Save result in real-time
        save_result_record(output_dir, error_result)
        return error_result
497
+
498
+
499
def format_bytes(bytes_size):
    """Render a byte count as a human-readable string (B/KB/MB/GB/TB)."""
    size = bytes_size
    for unit in ('B', 'KB', 'MB', 'GB'):
        if size < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    # Anything that survived four divisions is reported in terabytes.
    return f"{size:.2f} TB"
506
+
507
+
508
def format_speed(bytes_per_sec):
    """Render a transfer rate as a human-readable '<size>/s' string."""
    return "{}/s".format(format_bytes(bytes_per_sec))
511
+
512
+
513
def main():
    """Main program - two-phase concurrent processing.

    Phase 1 submits every song through a small thread pool whose throughput
    is capped by the global ``rate_limiter``; Phase 2 polls for completion
    and downloads results with a larger, un-throttled pool. Reads
    ``input_file`` (JSON array or JSONL), writes audio/lyrics/CSV into
    ``output_dir``, then prints a detailed performance report.
    """
    input_file = "cleaned_data_truncated.json"
    output_dir = "sunov5_truncated"
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Initialize logging
    global logger
    logger, log_file = setup_logging(output_dir)

    print_log("=" * 70)
    print_log("Suno API Batch Generation - V5 Special Edition")
    print_log("Strategy: Fast submission (5 requests/10s) + Parallel waiting + Detailed performance analysis")
    print_log(f"Log file: {log_file}")
    print_log("=" * 70)

    # Read input file
    try:
        all_data = []
        if input_file.endswith('.jsonl'):
            try:
                with open(input_file, 'r', encoding='utf-8') as f:
                    # Try reading first line to determine format
                    first_line = f.readline().strip()
                    if first_line.startswith('['):
                        # Looks like regular JSON array
                        f.seek(0)
                        all_data = json.load(f)
                    else:
                        # Try reading line by line
                        f.seek(0)
                        for i, line in enumerate(f):
                            line = line.strip()
                            if line:
                                all_data.append(json.loads(line))
            except json.JSONDecodeError:
                # If above parsing fails, try one final read as regular JSON
                print_log(f"Note: Failed to parse {input_file} as JSONL format, trying as regular JSON...")
                with open(input_file, 'r', encoding='utf-8') as f:
                    all_data = json.load(f)
        else:
            with open(input_file, 'r', encoding='utf-8') as f:
                all_data = json.load(f)

    except FileNotFoundError:
        print_log(f"File {input_file} not found.")
        return
    except json.JSONDecodeError as e:
        print_log(f"JSON parsing error: {e}")
        return

    # Initialize API
    api = SunoAPI(SUNO_API_KEY)

    print_log(f"\nPreparing to generate {len(all_data)} songs...")
    print_log(f"Start time: {time.strftime('%H:%M:%S')}\n")

    overall_start_time = time.time()

    # ===== Phase 1: Batch Submission =====
    print_log("\n" + "=" * 70)
    print_log("Phase 1: Batch Submission")
    print_log("=" * 70 + "\n")

    submit_start_time = time.time()
    submitted_tasks = []
    total_request_time = 0

    # Adjust rate limit: maximum 5 requests per 10 seconds
    # (overrides whatever the global limiter was constructed with)
    rate_limiter.max_requests = 5
    rate_limiter.time_window = 10
    rate_limiter.request_times.clear()
    print_log(f"Rate limit: {rate_limiter.max_requests} requests / {rate_limiter.time_window} seconds")

    # Only submit tasks that need to run
    tasks_to_run = []
    for i, data in enumerate(all_data, 1):
        tasks_to_run.append((i, data))

    print_log(f"Number of tasks to submit: {len(tasks_to_run)}")

    # Use thread pool for submission
    # Submission concurrency is controlled by rate_limiter, can be set to 5
    with ThreadPoolExecutor(max_workers=5) as executor:
        submit_futures = {
            executor.submit(submit_generation_task, api, idx, data): idx
            for idx, data in tasks_to_run
        }

        with tqdm(total=len(tasks_to_run), desc="Submitting tasks", unit="song") as pbar:
            for future in as_completed(submit_futures):
                result = future.result()
                submitted_tasks.append(result)
                if result.get('success') and 'request_time' in result:
                    total_request_time += result['request_time']
                pbar.update(1)

    submit_phase_time = time.time() - submit_start_time
    success_submits = sum(1 for t in submitted_tasks if t['success'])

    # Get rate limit statistics
    rate_limit_stats = rate_limiter.get_stats()

    print_log(f"\nSubmission phase complete: {success_submits}/{len(tasks_to_run)} successful")
    print_log(f" Total time: {submit_phase_time:.1f} seconds")
    print_log(f" Actual request time: {total_request_time:.2f} seconds")
    print_log(f" Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds ({rate_limit_stats['wait_count']} times)")
    if rate_limit_stats['wait_count'] > 0:
        print_log(f" Average wait time: {rate_limit_stats['avg_wait_time']:.2f} seconds/time")

    # ===== Phase 2: Parallel Waiting and Download =====
    print_log("\n" + "=" * 70)
    print_log("Phase 2: Wait for Generation and Download")
    print_log("=" * 70 + "\n")

    wait_start_time = time.time()
    final_results = []

    # Use more threads for parallel waiting (not rate limited)
    with ThreadPoolExecutor(max_workers=20) as executor:
        download_futures = {
            executor.submit(wait_and_download_result, api, task, output_dir): task
            for task in submitted_tasks if task['success']
        }

        # Add failed submission tasks to results
        for task in submitted_tasks:
            if not task['success']:
                final_results.append(task)

        with tqdm(total=len(download_futures), desc="Downloading results", unit="song") as pbar:
            for future in as_completed(download_futures):
                result = future.result()
                final_results.append(result)
                pbar.update(1)

    wait_phase_time = time.time() - wait_start_time

    # ===== Detailed Statistics and Report =====
    overall_time = time.time() - overall_start_time

    print_log("\n" + "=" * 70)
    print_log("Batch Generation Complete - Detailed Performance Report")
    print_log("=" * 70)

    success_count = sum(1 for r in final_results if r.get('success'))
    fail_count = len(final_results) - success_count
    total_tracks = sum(r.get('tracks_count', 0) for r in final_results if r.get('success'))

    successful_results = [r for r in final_results if r.get('success')]

    # Basic Statistics
    print_log(f"\n[Basic Statistics]")
    print_log(f" Total songs: {len(all_data)}")
    print_log(f" Successful: {success_count}")
    print_log(f" Failed: {fail_count}")
    print_log(f" Total tracks: {total_tracks}")
    if success_count > 0:
        avg_tracks = total_tracks / success_count
        print_log(f" Average tracks per song: {avg_tracks:.2f}")

    # Time Statistics
    print_log(f"\n[Time Statistics]")
    print_log(f" ├── Submission phase: {submit_phase_time:.1f} seconds")
    print_log(f" │ ├── Actual request time: {total_request_time:.2f} seconds")
    print_log(f" │ └── Rate limit waiting: {rate_limit_stats['total_wait_time']:.2f} seconds")
    print_log(f" ├── Generation waiting phase: {wait_phase_time:.1f} seconds")

    if successful_results:
        wait_times = [r.get('wait_time', 0) for r in successful_results if 'wait_time' in r]
        download_times = [r.get('download_time', 0) for r in successful_results if 'download_time' in r]

        if wait_times:
            avg_wait = sum(wait_times) / len(wait_times)
            min_wait = min(wait_times)
            max_wait = max(wait_times)
            print_log(f" │ ├── Average wait time: {avg_wait:.1f} seconds/song")
            print_log(f" │ ├── Fastest: {min_wait:.1f} seconds")
            print_log(f" │ └── Slowest: {max_wait:.1f} seconds")

        if download_times:
            total_download_time = sum(download_times)
            avg_download = total_download_time / len(download_times)
            print_log(f" ├── Download phase: {total_download_time:.1f} seconds")
            print_log(f" │ └── Average download time: {avg_download:.2f} seconds/song")

    print_log(f" └── Total time: {overall_time:.1f} seconds ({overall_time/60:.1f} minutes)")

    # Single Song Generation Statistics
    if successful_results:
        total_times = [r.get('total_time', 0) for r in successful_results if 'total_time' in r]
        if total_times:
            print_log(f"\n[Single Song Generation Statistics]")
            avg_time = sum(total_times) / len(total_times)
            min_time = min(total_times)
            max_time = max(total_times)
            print_log(f" Average total time per song: {avg_time:.1f} seconds")
            print_log(f" Fastest generation: {min_time:.1f} seconds")
            print_log(f" Slowest generation: {max_time:.1f} seconds")

        # Download Statistics
        total_download_bytes = sum(r.get('download_bytes', 0) for r in successful_results)
        total_download_count = sum(r.get('download_count', 0) for r in successful_results)

        if total_download_bytes > 0:
            print_log(f"\n[Download Statistics]")
            print_log(f" Total download: {format_bytes(total_download_bytes)}")
            print_log(f" Number of files: {total_download_count}")
            print_log(f" Average file size: {format_bytes(total_download_bytes / total_download_count)}")

            download_speeds = [r.get('avg_download_speed', 0) for r in successful_results if r.get('avg_download_speed', 0) > 0]
            if download_speeds:
                avg_speed = sum(download_speeds) / len(download_speeds)
                print_log(f" Average download speed: {format_speed(avg_speed)}")

        # Polling Statistics
        poll_counts = [r.get('poll_count', 0) for r in successful_results if 'poll_count' in r]
        if poll_counts:
            total_polls = sum(poll_counts)
            avg_polls = total_polls / len(poll_counts)
            print_log(f"\n[Polling Statistics]")
            print_log(f" Total polling count: {total_polls}")
            print_log(f" Average polling per song: {avg_polls:.1f}")

    # Efficiency Analysis
    print_log(f"\n[Efficiency Analysis]")
    if success_count > 0:
        throughput = success_count / (overall_time / 60)
        print_log(f" Actual throughput: {throughput:.2f} songs/minute")

        # Theoretical fastest time (assuming no rate limit)
        # NOTE: wait_times is bound inside the successful_results branch above;
        # success_count > 0 implies that branch ran, so the reference is safe.
        if wait_times:
            ideal_time = submit_phase_time - rate_limit_stats['total_wait_time'] + max(wait_times)
            efficiency = (ideal_time / overall_time) * 100
            print_log(f" Theoretical fastest time: {ideal_time:.1f} seconds")
            print_log(f" Concurrency efficiency: {efficiency:.1f}%")

    # Show failed songs
    if fail_count > 0:
        print_log("\n" + "=" * 70)
        print_log("Failed Songs List")
        print_log("=" * 70)
        for r in sorted(final_results, key=lambda x: x.get('song_index', 0)):
            if not r.get('success'):
                song_id = r.get('song_id', r.get('song_index', 'Unknown'))
                print_log(f" [{song_id}] {r.get('error', 'Unknown error')}")

    print_log("\n" + "=" * 70)
    print_log(f"All files saved to: {os.path.abspath(output_dir)}")
    print_log("=" * 70)
764
+
765
+
766
# Script entry point: run the two-phase batch generation.
if __name__ == '__main__':
    main()
768
+
baseline_generate/yue/__pycache__/infer_batch.cpython-311.pyc ADDED
Binary file (58.8 kB). View file
 
baseline_generate/yue/batch.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Batch music generation example script
# Requires yue source repository, main code change is infer_batch.py replacing infer.py

# Get absolute path of script directory (to avoid filesystem mount issues)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd 2>/dev/null || echo "xxx/YuE/inference")"

# Change to script directory (if possible, otherwise use absolute path)
cd "$SCRIPT_DIR" 2>/dev/null || true


# Set HuggingFace mirror
export HF_ENDPOINT=https://hf-mirror.com

# Set PyTorch CUDA memory management optimization
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

export CUDA_VISIBLE_DEVICES=1

# Set JSONL file path
JSONL_PATH=""

# Set output directory
OUTPUT_DIR=""

# Fail fast if the required paths above were left empty — otherwise python
# would be invoked with blank --jsonl_path/--output_dir arguments.
if [ -z "$JSONL_PATH" ] || [ -z "$OUTPUT_DIR" ]; then
    echo "Error: JSONL_PATH and OUTPUT_DIR must be set at the top of this script." >&2
    exit 1
fi

# Set processing range (optional)
# Example: only process first 5 songs
START_IDX=0
END_IDX=-1

# Set generation parameters
MAX_NEW_TOKENS=3500
REPETITION_PENALTY=1.1
RUN_N_SEGMENTS=24
STAGE2_BATCH_SIZE=16
CUDA_IDX=0
SEED=42
NO_SAMPLE=0

# Run batch generation (using absolute path)
python "$SCRIPT_DIR/infer_batch.py" \
    --jsonl_path "$JSONL_PATH" \
    --output_dir "$OUTPUT_DIR" \
    --start_idx $START_IDX \
    --end_idx $END_IDX \
    --max_new_tokens $MAX_NEW_TOKENS \
    --repetition_penalty $REPETITION_PENALTY \
    --run_n_segments $RUN_N_SEGMENTS \
    --stage2_batch_size $STAGE2_BATCH_SIZE \
    --cuda_idx $CUDA_IDX \
    --seed $SEED \
    --rescale \
    $( [ "$NO_SAMPLE" -eq 1 ] && echo "--no_sample" )
55
+
baseline_generate/yue/codecmanipulator.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import einops
4
+
5
+
6
class CodecManipulator(object):
    r"""
    **mm tokenizer v0.1**
    see codeclm/hf/mm_tokenizer_v0.1_hf/id2vocab.json

    text tokens:
    llama tokenizer 0~31999

    special tokens: "32000": "<EOD>", "32001": "<SOA>", "32002": "<EOA>", "32003": "<SOI>", "32004": "<EOI>", "32005": "<SOV>", "32006": "<EOV>", "32007": "<s_local>", "32008": "<e_local>", "32009": "<s_global>", "32010": "<e_global>", "32011": "<semantic>", "32012": "<acoustic>", "32013": "<low_level>", "32014": "<dac_16k>", "32015": "<dac_44k>", "32016": "<xcodec>", "32017": "<placeholder>", "32018": "<semantic_mert>", "32019": "<semantic_hubert>", "32020": "<visual>", "32021": "<semanticodec>"

    mm tokens:
    dac_16k: 4 codebook, 1024 vocab, 32022 - 36117
    dac_44k: 9 codebook, 1024 vocab, 36118 - 45333
    xcodec: 12 codebook, 1024 vocab, 45334 - 57621
    semantic mert: 1024, 57622 - 58645
    semantic hubert: 512, 58646 - 59157
    visual: 64000, not included in v0.1
    semanticodec 100tps 16384: semantic=16384, 59158 - 75541, acoustic=8192, 75542 - 83733

    Converts between per-codebook codec token grids (K, T) and the flat,
    globally-offset token id sequences used by the multimodal tokenizer.
    """
    def __init__(self, codec_type, quantizer_begin=None, n_quantizer=None, teacher_forcing=False, data_feature="codec"):
        # NOTE(review): configs such as "mert"/"hubert" lack a "num_codebooks"
        # entry, so constructing with those codec types raises KeyError below.
        # Confirm whether those types are ever instantiated directly.
        self.codec_type = codec_type
        self.mm_v0_2_cfg = {
            "dac16k": {"codebook_size": 1024, "num_codebooks": 4, "global_offset": 32022, "sep": ["<dac_16k>"], "fps": 50},
            "dac44k": {"codebook_size": 1024, "num_codebooks": 9, "global_offset": 36118, "sep": ["<dac_44k>"]},
            "xcodec": {"codebook_size": 1024, "num_codebooks": 12, "global_offset": 45334, "sep": ["<xcodec>"], "fps": 50},
            "mert": {"codebook_size": 1024, "global_offset": 57622, "sep": ["<semantic_mert>"]},
            "hubert": {"codebook_size": 512, "global_offset": 58646, "sep": ["<semantic_hubert>"]},
            "semantic/s": {"codebook_size": 16384, "num_codebooks": 1, "global_offset": 59158, "sep": ["<semanticodec>", "<semantic>"]},
            "semantic/a": {"codebook_size": 8192, "num_codebooks": 1, "global_offset": 75542, "sep": ["<semanticodec>", "<acoustic>"]},
            "semanticodec": {"codebook_size": [16384, 8192], "num_codebooks": 2, "global_offset": 59158, "sep": ["<semanticodec>"], "fps": 50},
            "special_tokens": {
                '<EOD>': 32000, '<SOA>': 32001, '<EOA>': 32002, '<SOI>': 32003, '<EOI>': 32004, '<SOV>': 32005, '<EOV>': 32006, '<s_local>': 32007, '<e_local>': 32008, '<s_global>': 32009, '<e_global>': 32010, '<semantic>': 32011, '<acoustic>': 32012, '<stage_1>': 32013, '<dac_16k>': 32014, '<dac_44k>': 32015, '<xcodec>': 32016, '<stage_2>': 32017, '<semantic_mert>': 32018, '<semantic_hubert>': 32019, '<visual>': 32020, '<semanticodec>': 32021
            },
            "metadata": {
                "len": 83734,
                "text_range": [0, 31999],
                "special_range": [32000, 32021],
                "mm_range": [32022, 83733]
            },
            "codec_range": {
                "dac16k": [32022, 36117],
                "dac44k": [36118, 45333],
                "xcodec": [45334, 57621],
                # "hifi16k": [53526, 57621],
                "mert": [57622, 58645],
                "hubert": [58646, 59157],
                "semantic/s": [59158, 75541],
                "semantic/a": [75542, 83733],
                "semanticodec": [59158, 83733]
            }
        }
        # NOTE(review): these instance attributes shadow the sep()/sep_ids()
        # methods defined at the bottom of the class, making those methods
        # unreachable on instances. Kept as-is to preserve the interface.
        self.sep = self.mm_v0_2_cfg[self.codec_type]["sep"]
        self.sep_ids = [self.mm_v0_2_cfg["special_tokens"][s] for s in self.sep]
        self.codebook_size = self.mm_v0_2_cfg[self.codec_type]["codebook_size"]
        self.num_codebooks = self.mm_v0_2_cfg[self.codec_type]["num_codebooks"]
        self.global_offset = self.mm_v0_2_cfg[self.codec_type]["global_offset"]
        self.fps = self.mm_v0_2_cfg[self.codec_type]["fps"] if "fps" in self.mm_v0_2_cfg[self.codec_type] else None

        self.quantizer_begin = quantizer_begin if quantizer_begin is not None else 0
        self.n_quantizer = n_quantizer if n_quantizer is not None else self.num_codebooks
        self.teacher_forcing = teacher_forcing
        self.data_feature = data_feature


    def offset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
        """Map raw per-codebook ids to global token ids.

        x: (K, T) array of raw codec ids. Each codebook k gets shifted by
        global_offset plus the cumulative size of the preceding codebooks.
        Returns only the rows [quantizer_begin, quantizer_begin+n_quantizer).
        """
        if isinstance(codebook_size, int):
            assert x.max() < codebook_size, f"max(x)={x.max()}, codebook_size={codebook_size}"
        elif isinstance(codebook_size, list):
            for i, cs in enumerate(codebook_size):
                assert x[i].max() < cs, f"max(x)={x[i].max()}, codebook_size={cs}, layer_id={i}"
        else:
            raise ValueError(f"codebook_size={codebook_size}")
        assert x.min() >= 0, f"min(x)={x.min()}"
        assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"

        _x = x.copy()
        _x = _x.astype(np.uint32)
        cum_offset = 0
        quantizer_begin = self.quantizer_begin
        quantizer_end = quantizer_begin+self.n_quantizer
        for k in range(self.quantizer_begin, quantizer_end): # k: quantizer_begin to quantizer_end - 1
            if isinstance(codebook_size, int):
                _x[k] += global_offset + k * codebook_size
            elif isinstance(codebook_size, list):
                _x[k] += global_offset + cum_offset
                cum_offset += codebook_size[k]
            else:
                raise ValueError(f"codebook_size={codebook_size}")
        return _x[quantizer_begin:quantizer_end]

    def unoffset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
        """Inverse of offset_tok_ids: map global token ids back to raw ids.

        x: (K, T) array of globally-offset ids, already sliced to the active
        quantizer rows (row 0 corresponds to quantizer_begin).
        """
        if isinstance(codebook_size, int):
            assert x.max() < global_offset + codebook_size * num_codebooks, f"max(x)={x.max()}, codebook_size={codebook_size}"
        elif isinstance(codebook_size, list):
            assert x.max() < global_offset + sum(codebook_size), f"max(x)={x.max()}, codebook_size={codebook_size}"
        assert x.min() >= global_offset, f"min(x)={x.min()}, global_offset={global_offset}"
        assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"

        _x = x.copy()
        _x = _x.astype(np.uint32)
        cum_offset = 0
        quantizer_begin = self.quantizer_begin
        quantizer_end = quantizer_begin+self.n_quantizer
        for k in range(quantizer_begin, quantizer_end):
            if isinstance(codebook_size, int):
                _x[k-quantizer_begin] -= global_offset + k * codebook_size
            elif isinstance(codebook_size, list):
                _x[k-quantizer_begin] -= global_offset + cum_offset
                cum_offset += codebook_size[k]
            else:
                raise ValueError(f"codebook_size={codebook_size}")
        return _x

    def flatten(self, x):
        """Interleave a (K, T) grid into a flat (T*K,) sequence (time-major)."""
        if len(x.shape) > 2:
            x = x.squeeze()
        assert x.shape[0] == self.num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
        return einops.rearrange(x, 'K T -> (T K)')

    def unflatten(self, x, n_quantizer=None):
        """Inverse of flatten: reshape a flat (T*K,) sequence back to (K, T).

        n_quantizer: number of codebooks K in the flat sequence; defaults to
        self.num_codebooks when omitted.
        """
        if x.ndim > 1 and x.shape[0] == 1:
            x = x.squeeze(0)
        assert len(x.shape) == 1
        assert x.shape[0] % self.num_codebooks == 0 or x.shape[0] % self.n_quantizer == 0, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
        # Fix: the original tested `n_quantizer != self.num_codebooks`, which is
        # True for the default None and passed K=None into einops (crash).
        if n_quantizer is not None and n_quantizer != self.num_codebooks:
            return einops.rearrange(x, '(T K) -> K T', K=n_quantizer)
        return einops.rearrange(x, '(T K) -> K T', K=self.num_codebooks)

    # def check_codec_type_from_path(self, path):
    #     if self.codec_type == "hifi16k":
    #         assert "academicodec_hifi_16k_320d_large_uni" in path

    def get_codec_type_from_range(self, ids):
        """Return the codec type whose global id range contains all of ids."""
        ids_range = [ids.min(), ids.max()]
        codec_range = self.mm_v0_2_cfg["codec_range"]
        for codec_type, r in codec_range.items():
            if ids_range[0] >= r[0] and ids_range[1] <= r[1]:
                return codec_type
        raise ValueError(f"ids_range={ids_range}, codec_range={codec_range}")

    def npy2ids(self, npy):
        """Load a (K, T) codec array (path or ndarray), offset and flatten it
        into a list of global token ids, validating the resulting id range."""
        if isinstance(npy, str):
            data = np.load(npy)
        elif isinstance(npy, np.ndarray):
            data = npy
        else:
            raise ValueError(f"not supported type: {type(npy)}")
        # data = data.squeeze()

        assert len(data.shape)==2, f'data shape: {data.shape} is not (n_codebook, seq_len)'
        data = self.offset_tok_ids(
            data,
            global_offset=self.global_offset,
            codebook_size=self.codebook_size,
            num_codebooks=self.num_codebooks,
        )
        data = self.flatten(data)
        codec_range = self.get_codec_type_from_range(data)
        assert codec_range == self.codec_type, f"get_codec_type_from_range(data)={codec_range}, self.codec_type={self.codec_type}"
        data = data.tolist()
        return data

    def ids2npy(self, token_ids):
        """Inverse of npy2ids: unflatten a global token id sequence and remove
        the offsets, returning a (K, T) array of raw codec ids."""
        # make sure token_ids starts with codebook 0
        if isinstance(self.codebook_size, int):
            codebook_0_range = (self.global_offset + self.quantizer_begin*self.codebook_size, self.global_offset + (self.quantizer_begin+1)*self.codebook_size)
        elif isinstance(self.codebook_size, list):
            codebook_0_range = (self.global_offset, self.global_offset + self.codebook_size[0])
        # Fix: the failure message previously printed token_ids[self.quantizer_begin]
        # while the assertion itself checks token_ids[0].
        assert token_ids[0] >= codebook_0_range[0] \
            and token_ids[0] < codebook_0_range[1], f"token_ids[0]={token_ids[0]}, codebook_0_range={codebook_0_range}"
        data = np.array(token_ids)
        data = self.unflatten(data, n_quantizer=self.n_quantizer)
        data = self.unoffset_tok_ids(
            data,
            global_offset=self.global_offset,
            codebook_size=self.codebook_size,
            num_codebooks=self.num_codebooks,
        )
        return data

    def npy_to_json_str(self, npy_path):
        """Serialize a codec .npy file as a JSON line with its token ids."""
        data = self.npy2ids(npy_path)
        return json.dumps({"text": data, "src": npy_path, "codec": self.codec_type})

    def sep(self):
        # NOTE(review): unreachable on instances — shadowed by self.sep set in __init__.
        return ''.join(self.sep)

    def sep_ids(self):
        # NOTE(review): unreachable on instances — shadowed by self.sep_ids set in __init__.
        return self.sep_ids
baseline_generate/yue/infer_batch.py ADDED
@@ -0,0 +1,904 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
4
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
5
+ import re
6
+ import random
7
+ import uuid
8
+ import copy
9
+ import json
10
+ from tqdm import tqdm
11
+ from collections import Counter
12
+ import argparse
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ from torchaudio.transforms import Resample
17
+ import soundfile as sf
18
+ from einops import rearrange
19
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
20
+ from omegaconf import OmegaConf
21
+ from codecmanipulator import CodecManipulator
22
+ from mmtokenizer import _MMSentencePieceTokenizer
23
+ from models.soundstream_hubert_new import SoundStream
24
+ from vocoder import build_codec_model, process_audio
25
+ from post_process_audio import replace_low_freq_with_energy_matched
26
+
27
+ os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
28
+
29
# ---------------------------------------------------------------------------
# CLI configuration for batch YuE two-stage song generation.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# Model Configuration:
parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
parser.add_argument("--stage2_model", type=str, default="m-a-p/YuE-s2-1B-general", help="The model checkpoint path or identifier for the Stage 2 model.")
parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
parser.add_argument("--repetition_penalty", type=float, default=1.1, help="repetition_penalty ranges from 1.0 to 2.0 (or higher in some cases). It controls the diversity and coherence of the audio tokens generated. The higher the value, the greater the discouragement of repetition. Setting value to 1.0 means no penalty.")
parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during generation. Each segment is ~30s (with default max_new_tokens=3000). For example: 2=~1min, 6=~3min, 8=~4min.")
parser.add_argument("--stage2_batch_size", type=int, default=4, help="The batch size used in Stage 2 inference.")
parser.add_argument(
    "--no_sample",
    action="store_true",
    help="If set, disable sampling in Stage 1 generation (i.e., use deterministic decoding). When enabled, top_p/temperature will be ignored.",
)
# Prompt - Batch processing parameters
parser.add_argument("--jsonl_path", type=str, required=True, help="The file path to a JSONL file containing genre and lyrics for batch processing.")
parser.add_argument("--start_idx", type=int, default=0, help="Start index in the JSONL file for batch processing.")
parser.add_argument("--end_idx", type=int, default=-1, help="End index in the JSONL file for batch processing. -1 means process all.")
parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
parser.add_argument("--use_dual_tracks_prompt", action="store_true", help="If set, the model will use dual tracks as a prompt during generation. The vocal and instrumental files should be specified using --vocal_track_prompt_path and --instrumental_track_prompt_path.")
parser.add_argument("--vocal_track_prompt_path", type=str, default="", help="The file path to a vocal track file to use as a reference prompt when --use_dual_tracks_prompt is enabled.")
parser.add_argument("--instrumental_track_prompt_path", type=str, default="", help="The file path to an instrumental track file to use as a reference prompt when --use_dual_tracks_prompt is enabled.")
# Output
parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
parser.add_argument("--cuda_idx", type=int, default=0)
parser.add_argument("--seed", type=int, default=42, help="An integer value to reproduce generation.")
# Config for xcodec and upsampler
parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML files for xcodec configurations.')
parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to Vocos config file.')
parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')


args = parser.parse_args()
if args.use_audio_prompt and not args.audio_prompt_path:
    raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
# BUGFIX: the original condition used `and` between the two `not` checks, so
# it only raised when BOTH paths were missing, and the message named the
# wrong flag (--inst_decoder_path). Dual-track prompting needs both paths.
if args.use_dual_tracks_prompt and (not args.vocal_track_prompt_path or not args.instrumental_track_prompt_path):
    raise FileNotFoundError("Please offer dual tracks prompt filepath using '--vocal_track_prompt_path' and '--instrumental_track_prompt_path', when you enable '--use_dual_tracks_prompt'!")

# Convenience aliases used throughout the script.
stage1_model = args.stage1_model
stage2_model = args.stage2_model
cuda_idx = args.cuda_idx
max_new_tokens = args.max_new_tokens
do_sample_stage1 = (not args.no_sample)
79
+
80
def seed_everything(seed=42):
    """Seed every RNG in play (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN kernels so generation is reproducible."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
87
+
88
seed_everything(args.seed)

# Read the JSONL file: one JSON object per line. Each record carries at
# least a "lyrics" field plus "genre" and/or "description" (consumed later
# by process_one_song).
print(f"Reading JSONL file: {args.jsonl_path}")
music_data_list = []
with open(args.jsonl_path, 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():  # skip blank lines
            music_data_list.append(json.loads(line))

# Determine processing range; --end_idx == -1 means "through the end".
start_idx = args.start_idx
end_idx = len(music_data_list) if args.end_idx == -1 else min(args.end_idx, len(music_data_list))
music_data_list = music_data_list[start_idx:end_idx]
print(f"Total {len(music_data_list)} songs to generate (indices {start_idx} to {end_idx-1})")
103
+
104
+ # Detect processed songs - check completion status of each stage
105
def check_song_status(song_idx, output_dir):
    """Inspect the output tree and report how far song `song_idx` has
    progressed through the three generation stages.

    Stage 1 = vocal + instrumental token dumps exist; Stage 2 = every
    stage-1 dump has a same-named stage-2 dump; Stage 3 = a final
    "*_mixed.mp3" exists anywhere under the song directory.

    Returns:
        (stage1_done, stage2_done, stage3_done, song_dir,
         stage1_output_set, stage2_output_dir) — the last three are None
        when no matching song directory exists.
    """
    not_found = (False, False, False, None, None, None)
    if not os.path.exists(output_dir):
        return not_found

    # Collect every "song_<idx>_*" directory belonging to this index.
    candidates = []
    for entry in os.listdir(output_dir):
        full = os.path.join(output_dir, entry)
        if not (entry.startswith('song_') and os.path.isdir(full)):
            continue
        try:
            if int(entry.split('_')[1]) == song_idx:
                candidates.append(full)
        except (ValueError, IndexError):
            continue

    if not candidates:
        return not_found

    # Several runs may have produced directories; prefer the newest one.
    song_dir = max(candidates, key=os.path.getmtime)

    # Stage 1: both a vocal (_vtrack) and instrumental (_itrack) dump.
    stage1_dir = os.path.join(song_dir, "stage1")
    stage1_done = False
    stage1_output_set = []
    if os.path.exists(stage1_dir):
        npy_files = [f for f in os.listdir(stage1_dir) if f.endswith('.npy')]
        vtracks = [f for f in npy_files if '_vtrack' in f]
        itracks = [f for f in npy_files if '_itrack' in f]
        if vtracks and itracks:
            stage1_done = True
            stage1_output_set = [os.path.join(stage1_dir, f) for f in vtracks + itracks]

    # Stage 2: every stage-1 basename must also exist under stage2/.
    stage2_dir = os.path.join(song_dir, "stage2")
    stage2_done = False
    if stage1_done and os.path.exists(stage2_dir):
        stage2_names = {f for f in os.listdir(stage2_dir) if f.endswith('.npy')}
        if stage1_output_set:
            needed = {os.path.basename(f) for f in stage1_output_set}
            stage2_done = needed.issubset(stage2_names)

    # Stage 3: any "*_mixed.mp3" anywhere below song_dir.
    stage3_done = any(
        any(name.endswith('_mixed.mp3') for name in files)
        for _, _, files in os.walk(song_dir)
    )

    return stage1_done, stage2_done, stage3_done, song_dir, stage1_output_set, stage2_dir
164
+
165
+ # Detect processing status of all songs
166
+ song_status_map = {} # {song_idx: (stage1_done, stage2_done, stage3_done, song_dir, stage1_output_set, stage2_output_dir)}
167
+ if os.path.exists(args.output_dir):
168
+ print(f"\nDetecting processed songs...")
169
+ for list_idx in range(len(music_data_list)):
170
+ song_idx = start_idx + list_idx
171
+ stage1_done, stage2_done, stage3_done, song_dir, stage1_output_set, stage2_output_dir = check_song_status(song_idx, args.output_dir)
172
+ if stage1_done or stage2_done or stage3_done:
173
+ song_status_map[song_idx] = (stage1_done, stage2_done, stage3_done, song_dir, stage1_output_set, stage2_output_dir)
174
+
175
+ if song_status_map:
176
+ fully_completed = [idx for idx, (s1, s2, s3, _, _, _) in song_status_map.items() if s3]
177
+ partial_completed = [idx for idx, (s1, s2, s3, _, _, _) in song_status_map.items() if not s3]
178
+ print(f"✓ Found {len(fully_completed)} fully completed songs: {sorted(fully_completed)}")
179
+ if partial_completed:
180
+ print(f"✓ Found {len(partial_completed)} partially completed songs: {sorted(partial_completed)}")
181
+ for idx in sorted(partial_completed):
182
+ s1, s2, s3, _, _, _ = song_status_map[idx]
183
+ status_parts = []
184
+ if s1: status_parts.append("Stage1")
185
+ if s2: status_parts.append("Stage2")
186
+ if s3: status_parts.append("Stage3")
187
+ print(f" Index {idx}: Completed {', '.join(status_parts)}")
188
+ remaining_count = len(music_data_list) - len(fully_completed)
189
+ print(f"✓ Will skip fully completed songs, {remaining_count} songs remaining to process")
190
+ else:
191
+ print(f"✓ No processed songs found, will start from the beginning")
192
+ else:
193
+ print(f"✓ Output directory does not exist, will start from the beginning")
194
+
195
+ # Load tokenizer and model
196
# Load tokenizer and the Stage 1 LM.
device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
print("Loading Stage 1 model...")
model = AutoModelForCausalLM.from_pretrained(
    stage1_model,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2", # Using flash_attention_2 for better performance
    # device_map="auto",
)
# to device, if gpu is available
model.to(device)
model.eval()

# NOTE(review): this is a lexicographic string comparison, not a semantic
# version check (it would misclassify e.g. "10.0.0"); fine for current
# torch versions but worth replacing with a proper version parse.
if torch.__version__ >= "2.0.0":
    try:
        model = torch.compile(model)
    except Exception as e:
        print(f"Warning: torch.compile not available: {e}")

# Stage 1 works on 1 codebook; Stage 2 refines to 8 codebooks.
codectool = CodecManipulator("xcodec", 0, 1)
codectool_stage2 = CodecManipulator("xcodec", 0, 8)
model_config = OmegaConf.load(args.basic_model_config)
# SECURITY: eval() on a class name read from the YAML config — only safe
# because the config ships with the repo; never point this at untrusted files.
codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
# Load checkpoint with weights_only=False to allow OmegaConf types
# Note: Only use this if you trust the checkpoint source
parameter_dict = torch.load(args.resume_path, map_location='cpu', weights_only=False)
codec_model.load_state_dict(parameter_dict['codec_model'])
codec_model.to(device)
codec_model.eval()
225
+
226
class BlockTokenRangeProcessor(LogitsProcessor):
    """Logits processor that forbids every token id in [start_id, end_id)
    by driving its score to -inf before sampling/argmax."""

    def __init__(self, start_id, end_id):
        # Materialize the range once; used as a fancy index each call.
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = float("-inf")
        return scores
233
+
234
def load_audio_mono(filepath, sampling_rate=16000):
    """Load an audio file, downmix it to mono, and resample to
    `sampling_rate` if necessary. Returns a (1, samples) tensor."""
    waveform, source_rate = torchaudio.load(filepath)
    mono = torch.mean(waveform, dim=0, keepdim=True)
    if source_rate != sampling_rate:
        mono = Resample(orig_freq=source_rate, new_freq=sampling_rate)(mono)
    return mono
243
+
244
def encode_audio(codec_model, audio_prompt, device, target_bw=0.5):
    """Encode a waveform into codec codebook indices.

    Returns the codes as an int16 numpy array with the first two axes of the
    model output swapped. NOTE: a rank-<3 `audio_prompt` is batch-expanded
    IN PLACE via unsqueeze_, mirroring the original behaviour.
    """
    if len(audio_prompt.shape) < 3:
        audio_prompt.unsqueeze_(0)
    with torch.no_grad():
        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=target_bw)
    return raw_codes.transpose(0, 1).cpu().numpy().astype(np.int16)
252
+
253
def split_lyrics(lyrics):
    """Split raw lyrics into structured segments, per YuE best practices.

    Lyrics are expected to be segmented with structure tags ([verse],
    [chorus], [bridge], [outro], ...; complex tags like [Chorus (Outro)]
    also match). Each returned element is "[tag]\\ncontent\\n\\n". Keep each
    segment ~30s of content (at max_new_tokens=3000) and avoid [intro],
    which is known to be unstable — prefer [verse]/[chorus] openers.

    Args:
        lyrics: Raw lyrics string.

    Returns:
        List of "[tag]\\ncontent\\n\\n" segment strings (empty if no tags).
    """
    # Match "[tag]" plus everything up to the next tag or end of string.
    section_re = re.compile(r"\[([^\]]+)\](.*?)(?=\[|\Z)", re.DOTALL)
    return [
        f"[{tag}]\n{body.strip()}\n\n"
        for tag, body in section_re.findall(lyrics)
    ]
276
+
277
def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
    """Write a waveform to disk as 16-bit PCM.

    Args:
        wav: (channels, samples) float waveform tensor.
        path: destination file path; parent directories are created as needed.
        sample_rate: output sample rate in Hz.
        rescale: if True, scale down so the peak fits under 0.99 instead of
            hard-clipping to [-0.99, 0.99].
    """
    folder_path = os.path.dirname(path)
    # BUGFIX: the original called os.makedirs('') for bare filenames (dirname
    # is empty) and could race if two processes created the same directory;
    # guarding on folder_path and exist_ok=True handles both.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)
    limit = 0.99
    max_val = wav.abs().max()
    wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
285
+
286
def stage2_generate(model, prompt, batch_size=16):
    """Stage 2: expand stage-1 single-codebook codes into 8-codebook codes.

    Runs a frame-by-frame teacher-forcing loop: each iteration appends one
    stage-1 (codebook-0) token and asks the model for exactly 7 more tokens
    (the remaining codebooks of that frame). Uses the module-level
    `codectool`, `mmtokenizer` and `device`.

    Args:
        model: the Stage 2 causal LM.
        prompt: (1, seq_len) stage-1 code array.
        batch_size: >1 splits the prompt into 300-frame (6 s at 50 fps)
            rows that are decoded in parallel.

    Returns:
        1-D numpy array of generated stage-2 token ids (prompt excluded).
    """
    codec_ids = codectool.unflatten(prompt, n_quantizer=1)
    codec_ids = codectool.offset_tok_ids(
        codec_ids,
        global_offset=codectool.global_offset,
        codebook_size=codectool.codebook_size,
        num_codebooks=codectool.num_codebooks,
    ).astype(np.int32)

    # Prepare prompt_ids based on batch size or single input
    if batch_size > 1:
        # Slice the sequence into batch_size rows of 300 frames each; the
        # caller guarantees the prompt length matches batch_size * 300.
        codec_list = []
        for i in range(batch_size):
            idx_begin = i * 300
            idx_end = (i + 1) * 300
            codec_list.append(codec_ids[:, idx_begin:idx_end])

        codec_ids = np.concatenate(codec_list, axis=0)
        prompt_ids = np.concatenate(
            [
                np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
                codec_ids,
                np.tile([mmtokenizer.stage_2], (batch_size, 1)),
            ],
            axis=1
        )
    else:
        prompt_ids = np.concatenate([
            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
            codec_ids.flatten(),  # Flatten the 2D array to 1D
            np.array([mmtokenizer.stage_2])
        ]).astype(np.int32)
        prompt_ids = prompt_ids[np.newaxis, ...]

    codec_ids = torch.as_tensor(codec_ids).to(device)
    prompt_ids = torch.as_tensor(prompt_ids).to(device)
    len_prompt = prompt_ids.shape[-1]

    # Restrict decoding to the stage-2 codec token-id range.
    block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])

    # Teacher forcing generate loop: one stage-1 frame in, 7 tokens out.
    for frames_idx in range(codec_ids.shape[1]):
        cb0 = codec_ids[:, frames_idx:frames_idx+1]
        prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
        input_ids = prompt_ids

        with torch.no_grad():
            stage2_output = model.generate(input_ids=input_ids,
                min_new_tokens=7,
                max_new_tokens=7,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=block_list,
            )

        assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
        prompt_ids = stage2_output

    # Return output based on batch size; rows are re-joined in time order.
    if batch_size > 1:
        output = prompt_ids.cpu().numpy()[:, len_prompt:]
        output_list = [output[i] for i in range(batch_size)]
        output = np.concatenate(output_list, axis=0)
    else:
        output = prompt_ids[0].cpu().numpy()[len_prompt:]

    return output
353
+
354
def sanitize_genres_for_filename(genres, max_length=80):
    """
    Clean and truncate a genres string for use in a filename.
    Keeps the result within max_length so the full filename stays under the
    Linux 255-byte limit.

    Args:
        genres: Raw genres string.
        max_length: Maximum length of the genres part (default 80, leaving
            space for the other filename parameters).

    Returns:
        Cleaned genres string ("Unknown" if nothing usable remains).
    """
    if not genres:
        return "Unknown"

    # Clean unsafe filename characters (Windows-reserved + control chars)
    genres_clean = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', genres)
    genres_clean = genres_clean.strip('_').strip()

    # BUGFIX: a string made entirely of unsafe characters previously cleaned
    # down to "" and produced broken filenames like "_tp0@93...".
    if not genres_clean:
        return "Unknown"

    # If comma-separated tags, keep as many leading tags as fit
    if ',' in genres_clean:
        tags = [tag.strip() for tag in genres_clean.split(',')]
        result_tags = []
        current_length = 0
        for tag in tags:
            if current_length + len(tag) + 1 <= max_length:  # +1 for comma
                result_tags.append(tag)
                current_length += len(tag) + 1
            else:
                break
        if result_tags:
            genres_clean = ','.join(result_tags)
        else:
            # If even the first tag is too long, truncate it directly
            genres_clean = tags[0][:max_length] if tags else genres_clean[:max_length]

    # If still too long, truncate directly
    if len(genres_clean) > max_length:
        genres_clean = genres_clean[:max_length]

    # Replace spaces with hyphens (for consistency)
    genres_clean = genres_clean.replace(' ', '-')

    return genres_clean
399
+
400
def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
    """Run Stage 2 over every stage-1 .npy dump and save the refined codes.

    Already-existing outputs are skipped (resume support). Each prompt is
    cut into whole 6-second chunks (300 frames at 50 fps), dispatched to
    stage2_generate in groups of up to `batch_size`, and any sub-6s tail is
    processed separately with batch_size=1.

    Returns:
        List of saved stage-2 .npy file paths (one per stage-1 input).
    """
    stage2_result = []
    for i in tqdm(range(len(stage1_output_set)), desc="Stage 2 inference"):
        output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))

        if os.path.exists(output_filename):
            print(f'{output_filename} stage2 has done.')
            stage2_result.append(output_filename)
            continue

        # Load the prompt
        prompt = np.load(stage1_output_set[i]).astype(np.int32)

        # Only accept 6s segments: round duration down to a multiple of 6 s.
        output_duration = prompt.shape[-1] // 50 // 6 * 6
        num_batch = output_duration // 6

        if num_batch <= batch_size:
            # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
            output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
        else:
            # If num_batch is greater than batch_size, process in chunks of batch_size
            segments = []
            num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)

            for seg in range(num_segments):
                start_idx = seg * batch_size * 300
                # Ensure the end_idx does not exceed the available length
                end_idx = min((seg + 1) * batch_size * 300, output_duration*50)  # Adjust the last segment
                # Last segment may carry fewer than batch_size 6-s chunks.
                current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
                segment = stage2_generate(
                    model,
                    prompt[:, start_idx:end_idx],
                    batch_size=current_batch_size
                )
                segments.append(segment)

            # Concatenate all the segments
            output = np.concatenate(segments, axis=0)

        # Process the ending part of the prompt (the sub-6s remainder).
        if output_duration*50 != prompt.shape[-1]:
            ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
            output = np.concatenate([output, ending], axis=0)
        output = codectool_stage2.ids2npy(output)

        # Fix invalid codes (a dirty solution, which may harm the quality of audio)
        # We are trying to find better one
        # NOTE(review): the inner loop variable `i` shadows the outer file
        # index; benign only because nothing after this uses the outer `i`.
        fixed_output = copy.deepcopy(output)
        for i, line in enumerate(output):
            for j, element in enumerate(line):
                if element < 0 or element > 1023:
                    # Replace out-of-range codes with the row's most common code.
                    counter = Counter(line)
                    most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
                    fixed_output[i, j] = most_frequant
        # save output
        np.save(output_filename, fixed_output)
        stage2_result.append(output_filename)
    return stage2_result
459
+
460
def process_one_song(music_data, song_idx, total_songs):
    """Run Stage 1 generation for a single song record.

    Builds the text prompt from genre tags + segmented lyrics, optionally
    prepends an audio / dual-track reference, then autoregressively
    generates interleaved vocal/instrumental codec tokens segment by
    segment, and saves the de-interleaved tracks as .npy dumps.

    Args:
        music_data: JSONL record with 'lyrics' and 'genre'/'description'.
        song_idx: global index of the song (used in the output dir name).
        total_songs: total batch size (currently only informational).

    Returns:
        (stage1_output_set, stage2_output_dir, song_output_dir) where
        stage1_output_set lists the saved vocal/instrumental .npy paths.
    """

    # Compatible with genre and description fields
    genres = music_data.get('genre') or music_data.get('description', '')
    lyrics_raw = music_data['lyrics']
    description = music_data.get('description', '')

    print(f"Description: {description[:100]}...")
    print(f"Genre tags: {genres}")

    # ===== Print original lyrics =====
    print("\n" + "="*60)
    print("【Original Lyrics (lyrics_raw)】")
    print("="*60)
    print(lyrics_raw)
    print("="*60 + "\n")

    lyrics = split_lyrics(lyrics_raw)

    # Validate lyrics format and give warnings (following official best practices)
    print(f"Lyrics analysis: Identified {len(lyrics)} segments")

    # ===== Print segmented lyrics =====
    print("\n" + "="*60)
    print("【Segmented Lyrics (lyrics)】")
    print("="*60)
    for i, seg in enumerate(lyrics):
        tag = seg.split('\n')[0].strip()
        # Check if unstable [intro] tag is used
        if 'intro' in tag.lower():
            print(f" ⚠️ Warning: Segment {i+1} uses {tag} tag, official recommendation is to avoid [intro], use [verse] or [chorus] instead")
        else:
            print(f" Segment {i+1}. {tag}")
        # Print each segment's content (limit length)
        content = seg.strip()
        if len(content) > 150:
            print(f" Content preview: {content[:150]}...")
        else:
            print(f" Content: {content}")
        print()
    print("="*60 + "\n")

    # Create output directory for this song (uuid avoids collisions on re-runs)
    random_id = uuid.uuid4()
    song_output_dir = os.path.join(args.output_dir, f"song_{song_idx:04d}_{random_id}")
    stage1_output_dir = os.path.join(song_output_dir, "stage1")
    stage2_output_dir = os.path.join(song_output_dir, "stage2")
    os.makedirs(stage1_output_dir, exist_ok=True)
    os.makedirs(stage2_output_dir, exist_ok=True)

    # Stage 1: Generate audio tokens
    print("--- Stage 1: Generate audio tokens ---")
    stage1_output_set = []
    full_lyrics = "\n".join(lyrics)
    # prompt_texts[0] is the instruction header; the rest are the segments.
    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
    prompt_texts += lyrics

    # ===== Print prompt texts passed to model =====
    print("\n" + "="*60)
    print("【Prompt Texts Passed to Model (prompt_texts)】")
    print("="*60)
    print(f"Total {len(prompt_texts)} prompts (first is full prompt, subsequent are segments)\n")
    for i, pt in enumerate(prompt_texts):
        if i == 0:
            print(f"Prompt {i} [Full prompt header]:")
            if len(pt) > 300:
                print(f"{pt[:300]}...")
            else:
                print(pt)
        else:
            print(f"\nPrompt {i} [Segment {i}]:")
            if len(pt) > 200:
                print(f"{pt[:200]}...")
            else:
                print(pt)
    print("="*60 + "\n")

    output_seq = None
    # Here is suggested decoding config
    top_p = 0.93
    temperature = 1.0
    repetition_penalty = args.repetition_penalty
    if not do_sample_stage1:
        print("Note: --no_sample is enabled, Stage 1 will use deterministic decoding; top_p/temperature will be ignored.")
    # special tokens
    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
    # Format text prompt
    # +1 because prompt_texts[0] is the full prompt which will be skipped, so need len(lyrics)+1 to process all segments
    run_n_segments = min(args.run_n_segments+1, len(lyrics)+1)

    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference")):
        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
        # First real segment gets stronger classifier-free guidance.
        guidance_scale = 1.5 if i <=1 else 1.2

        # ===== Print currently processing segment =====
        if i == 0:
            print(f"\n[Segment {i}] Skipped (full prompt header)")
        else:
            print(f"\n" + "-"*60)
            print(f"[Processing segment {i}/{len(prompt_texts[:run_n_segments])-1}]")
            print("-"*60)
            tag_line = section_text.split('\n')[0] if '\n' in section_text else section_text[:50]
            print(f"Segment tag: {tag_line}")
            print(f"Segment content length: {len(section_text)} characters")
            if len(section_text) > 200:
                print(f"Segment content preview: {section_text[:200]}...")
            else:
                print(f"Segment content: {section_text}")
            print("-"*60)

        if i==0:
            continue
        if i==1:
            # First segment: build the full head (instruction header plus an
            # optional [start_of_reference] audio-prompt block).
            if args.use_dual_tracks_prompt or args.use_audio_prompt:
                if args.use_dual_tracks_prompt:
                    vocals_ids = load_audio_mono(args.vocal_track_prompt_path)
                    instrumental_ids = load_audio_mono(args.instrumental_track_prompt_path)
                    vocals_ids = encode_audio(codec_model, vocals_ids, device, target_bw=0.5)
                    instrumental_ids = encode_audio(codec_model, instrumental_ids, device, target_bw=0.5)
                    vocals_ids = codectool.npy2ids(vocals_ids[0])
                    instrumental_ids = codectool.npy2ids(instrumental_ids[0])
                    # Interleave vocal/instrumental frame-by-frame (v0,i0,v1,i1,...)
                    ids_segment_interleaved = rearrange([np.array(vocals_ids), np.array(instrumental_ids)], 'b n -> (n b)')
                    audio_prompt_codec = ids_segment_interleaved[int(args.prompt_start_time*50*2): int(args.prompt_end_time*50*2)]
                    audio_prompt_codec = audio_prompt_codec.tolist()
                elif args.use_audio_prompt:
                    audio_prompt = load_audio_mono(args.audio_prompt_path)
                    raw_codes = encode_audio(codec_model, audio_prompt, device, target_bw=0.5)
                    # Format audio prompt
                    code_ids = codectool.npy2ids(raw_codes[0])
                    audio_prompt_codec = code_ids[int(args.prompt_start_time *50): int(args.prompt_end_time *50)] # 50 is tps of xcodec
                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
            else:
                head_id = mmtokenizer.tokenize(prompt_texts[0])
            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
        else:
            # Later segments only append an [end_of_segment][start_of_segment] bridge.
            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids

        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
        # raw_output accumulates all prior prompts + generations (set at i==1).
        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
        # Use window slicing in case output sequence exceeds the context of model
        max_context = 16384-max_new_tokens-1
        if input_ids.shape[-1] > max_context:
            print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
            input_ids = input_ids[:, -(max_context):]
        with torch.no_grad():
            output_seq = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                min_new_tokens=100,
                do_sample=do_sample_stage1,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
                guidance_scale=guidance_scale,
            )
        # Force-terminate the segment with eoa if generation stopped early.
        if output_seq[0][-1].item() != mmtokenizer.eoa:
            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
        if i > 1:
            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
        else:
            raw_output = output_seq

    # save raw output and check sanity: soa/eoa markers must pair up
    ids = raw_output[0].cpu().numpy()
    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
    if len(soa_idx)!=len(eoa_idx):
        raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')

    vocals = []
    instrumentals = []
    # Skip the first soa/eoa pair when it belongs to the reference prompt.
    range_begin = 1 if args.use_audio_prompt or args.use_dual_tracks_prompt else 0
    for i in range(range_begin, len(soa_idx)):
        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
        if codec_ids[0] == 32016:
            codec_ids = codec_ids[1:]
        # Trim to an even length so vocal/instrumental pairs stay aligned.
        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
        # De-interleave (v,i,v,i,...) into separate vocal/instrumental tracks.
        vocals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[0])
        vocals.append(vocals_ids)
        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[1])
        instrumentals.append(instrumentals_ids)
    vocals = np.concatenate(vocals, axis=1)
    instrumentals = np.concatenate(instrumentals, axis=1)
    # Clean genres string to avoid filename being too long
    genres_clean = sanitize_genres_for_filename(genres, max_length=80)
    # '.'->'@' keeps exactly one '.' in the final filename (the .npy suffix).
    vocal_save_path = os.path.join(stage1_output_dir, f"{genres_clean}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_{random_id}_vtrack".replace('.', '@')+'.npy')
    inst_save_path = os.path.join(stage1_output_dir, f"{genres_clean}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_{random_id}_itrack".replace('.', '@')+'.npy')
    np.save(vocal_save_path, vocals)
    np.save(inst_save_path, instrumentals)
    stage1_output_set.append(vocal_save_path)
    stage1_output_set.append(inst_save_path)

    return stage1_output_set, stage2_output_dir, song_output_dir
661
+
662
+ # Load Stage 2 model and vocoder (load only once)
663
+ print("\n" + "="*60)
664
+ print("Loading Stage 2 model...")
665
+ print("="*60)
666
+ model_stage2 = AutoModelForCausalLM.from_pretrained(
667
+ stage2_model,
668
+ torch_dtype=torch.bfloat16,
669
+ attn_implementation="flash_attention_2", # Using flash_attention_2 for better performance
670
+ # device_map="auto",
671
+ )
672
+ model_stage2.to(device)
673
+ model_stage2.eval()
674
+
675
+ if torch.__version__ >= "2.0.0":
676
+ try:
677
+ model_stage2 = torch.compile(model_stage2)
678
+ except Exception as e:
679
+ print(f"Warning: torch.compile not available: {e}")
680
+
681
+ print("Loading vocoder...")
682
+ vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
683
+
684
+ # Batch process all songs - process each song completely before continuing to next
685
+ all_results = []
686
+ skipped_count = 0
687
+ for list_idx, music_data in enumerate(music_data_list):
688
+ # Calculate actual song index (considering start_idx offset)
689
+ song_idx = start_idx + list_idx
690
+
691
+ try:
692
+ # Compatible with genre and description fields
693
+ genres = music_data.get('genre') or music_data.get('description', '')
694
+
695
+ # Check processing status
696
+ stage1_done = False
697
+ stage2_done = False
698
+ stage3_done = False
699
+ song_output_dir = None
700
+ stage1_output_set = None
701
+ stage2_output_dir = None
702
+
703
+ if song_idx in song_status_map:
704
+ stage1_done, stage2_done, stage3_done, song_output_dir, stage1_output_set, stage2_output_dir = song_status_map[song_idx]
705
+
706
+ # If all completed, skip
707
+ if stage3_done:
708
+ print(f"\n{'='*60}")
709
+ print(f"⏭️ Skipping song {list_idx+1}/{len(music_data_list)} (index {song_idx}, fully completed)")
710
+ print(f"{'='*60}")
711
+ skipped_count += 1
712
+ continue
713
+
714
+ # Decide which stage to start from based on completion status
715
+ print(f"\n{'='*60}")
716
+ print(f"Starting to process song {list_idx+1}/{len(music_data_list)} (index {song_idx})")
717
+ if stage1_done:
718
+ print(f" ✓ Stage 1 completed, will start from Stage 2")
719
+ if stage2_done:
720
+ print(f" ✓ Stage 2 completed, will start from Stage 3")
721
+ print(f"{'='*60}")
722
+
723
+ # Stage 1: Generate audio tokens (if not completed)
724
+ if not stage1_done:
725
+ stage1_output_set, stage2_output_dir, song_output_dir = process_one_song(music_data, song_idx, len(music_data_list))
726
+ print(f"✓ Stage 1 completed, generated {len(stage1_output_set)} files")
727
+ for f in stage1_output_set:
728
+ print(f" - {os.path.basename(f)}")
729
+ else:
730
+ print(f"⏭️ Skipping Stage 1 (completed)")
731
+ print(f" Using existing Stage 1 outputs:")
732
+ for f in stage1_output_set:
733
+ print(f" - {os.path.basename(f)}")
734
+
735
+ # Note: Do not unload Stage 1 model here, as subsequent songs still need it
736
+ # Stage 1 model will be unloaded uniformly after all songs are processed
737
+
738
+ # Stage 2: Process audio tokens (if not completed)
739
+ if not stage2_done:
740
+ print(f"\n--- Stage 2: Processing song {list_idx+1} (index {song_idx}) ---")
741
+ stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=args.stage2_batch_size)
742
+ print(f"✓ Stage 2 completed, generated {len(stage2_result)} files")
743
+ for f in stage2_result:
744
+ print(f" - {os.path.basename(f)}")
745
+ else:
746
+ print(f"\n⏭️ Skipping Stage 2 (completed)")
747
+ # Get existing stage2 results
748
+ stage2_result = []
749
+ if os.path.exists(stage2_output_dir):
750
+ for f in stage1_output_set:
751
+ basename = os.path.basename(f)
752
+ stage2_file = os.path.join(stage2_output_dir, basename)
753
+ if os.path.exists(stage2_file):
754
+ stage2_result.append(stage2_file)
755
+ print(f" Using existing Stage 2 outputs:")
756
+ for f in stage2_result:
757
+ print(f" - {os.path.basename(f)}")
758
+
759
+ # Stage 3: Reconstruct audio and mix (if not completed)
760
+ final_output = None
761
+ if not stage3_done:
762
+ print(f"\n--- Stage 3: Reconstructing audio for song {list_idx+1} (index {song_idx}) ---")
763
+
764
+ # reconstruct tracks
765
+ recons_output_dir = os.path.join(song_output_dir, "recons")
766
+ recons_mix_dir = os.path.join(recons_output_dir, 'mix')
767
+ os.makedirs(recons_mix_dir, exist_ok=True)
768
+ tracks = []
769
+ for npy in stage2_result:
770
+ codec_result = np.load(npy)
771
+ decodec_rlt=[]
772
+ with torch.no_grad():
773
+ decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
774
+ decoded_waveform = decoded_waveform.cpu().squeeze(0)
775
+ decodec_rlt.append(torch.as_tensor(decoded_waveform))
776
+ decodec_rlt = torch.cat(decodec_rlt, dim=-1)
777
+ save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
778
+ tracks.append(save_path)
779
+ save_audio(decodec_rlt, save_path, 16000)
780
+
781
+ # mix tracks
782
+ recons_mix = None
783
+ for inst_path in tracks:
784
+ try:
785
+ if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
786
+ and '_itrack' in inst_path:
787
+ # find pair
788
+ vocal_path = inst_path.replace('_itrack', '_vtrack')
789
+ if not os.path.exists(vocal_path):
790
+ continue
791
+ # mix
792
+ recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('_itrack', '_mixed'))
793
+ vocal_stem, sr = sf.read(inst_path)
794
+ instrumental_stem, _ = sf.read(vocal_path)
795
+ mix_stem = (vocal_stem + instrumental_stem) / 1
796
+ sf.write(recons_mix, mix_stem, sr)
797
+ except Exception as e:
798
+ print(e)
799
+
800
+ # vocoder to upsample audios
801
+ vocoder_output_dir = os.path.join(song_output_dir, 'vocoder')
802
+ vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
803
+ vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
804
+ os.makedirs(vocoder_mix_dir, exist_ok=True)
805
+ os.makedirs(vocoder_stems_dir, exist_ok=True)
806
+
807
+ for npy in stage2_result:
808
+ if '_itrack' in npy:
809
+ # Process instrumental
810
+ instrumental_output = process_audio(
811
+ npy,
812
+ os.path.join(vocoder_stems_dir, 'itrack.mp3'),
813
+ args.rescale,
814
+ args,
815
+ inst_decoder,
816
+ codec_model
817
+ )
818
+ else:
819
+ # Process vocal
820
+ vocal_output = process_audio(
821
+ npy,
822
+ os.path.join(vocoder_stems_dir, 'vtrack.mp3'),
823
+ args.rescale,
824
+ args,
825
+ vocal_decoder,
826
+ codec_model
827
+ )
828
+
829
+ # mix tracks
830
+ vocoder_mix = None
831
+ try:
832
+ mix_output = instrumental_output + vocal_output
833
+ vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
834
+ save_audio(mix_output, vocoder_mix, 44100, args.rescale)
835
+ print(f"Created mix: {vocoder_mix}")
836
+ except RuntimeError as e:
837
+ print(e)
838
+ print(f"Mix failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
839
+
840
+ # Post process
841
+ if recons_mix and vocoder_mix:
842
+ final_output = os.path.join(song_output_dir, os.path.basename(recons_mix))
843
+ replace_low_freq_with_energy_matched(
844
+ a_file=recons_mix, # 16kHz
845
+ b_file=vocoder_mix, # 48kHz
846
+ c_file=final_output,
847
+ cutoff_freq=5500.0
848
+ )
849
+ print(f"✓ Song {list_idx+1} (index {song_idx}) completed! Output: {final_output}")
850
+ else:
851
+ print(f"\n⏭️ Skipping Stage 3 (completed)")
852
+ # Find final output file (usually in song_dir root directory)
853
+ # First check root directory
854
+ root_files = [f for f in os.listdir(song_output_dir) if f.endswith('_mixed.mp3')]
855
+ if root_files:
856
+ final_output = os.path.join(song_output_dir, root_files[0])
857
+ else:
858
+ # If root directory doesn't have it, traverse subdirectories to find
859
+ for root, dirs, files in os.walk(song_output_dir):
860
+ for f in files:
861
+ if f.endswith('_mixed.mp3'):
862
+ final_output = os.path.join(root, f)
863
+ break
864
+ if final_output:
865
+ break
866
+ if final_output:
867
+ print(f" Final output: {final_output}")
868
+
869
+ all_results.append({
870
+ 'song_idx': song_idx,
871
+ 'genres': genres,
872
+ 'output_path': final_output if recons_mix and vocoder_mix else None
873
+ })
874
+
875
+ except Exception as e:
876
+ print(f"✗ Error processing song {list_idx+1} (index {song_idx}): {e}")
877
+ import traceback
878
+ traceback.print_exc()
879
+ continue
880
+
881
+ # After all songs are processed, unload models to free memory
882
+ if not args.disable_offload_model:
883
+ print("\nCleaning up models to free memory...")
884
+ if 'model' in locals():
885
+ model.cpu()
886
+ del model
887
+ if 'model_stage2' in locals():
888
+ model_stage2.cpu()
889
+ del model_stage2
890
+ torch.cuda.empty_cache()
891
+ print("Models unloaded")
892
+
893
+ print("\n" + "="*60)
894
+ print("Batch generation complete!")
895
+ newly_processed = len([r for r in all_results if r.get('output_path')])
896
+ print(f"✓ Newly processed: {newly_processed} songs")
897
+ if skipped_count > 0:
898
+ print(f"⏭️ Skipped (already completed): {skipped_count} songs")
899
+ print(f"📊 Total completed: {newly_processed + skipped_count} songs")
900
+ print("="*60)
901
+ for result in all_results:
902
+ if result.get('output_path'):
903
+ print(f"Song {result['song_idx']+1}: {result['output_path']}")
904
+
baseline_generate/yue/mmtokenizer.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+ from abc import abstractmethod
3
+
4
+
5
class AbstractTokenizer(ABC):
    """Abstract base class defining the tokenizer interface.

    Concrete subclasses must provide the vocabulary maps and `tokenize`;
    the optional special-token accessors raise NotImplementedError by
    default so subclasses only implement what they support.
    """

    def __init__(self, name):
        # Human-readable tokenizer name, echoed in the error messages below.
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError(
            f'detokenizer is not implemented for {self.name} tokenizer')

    @property
    def cls(self):
        raise NotImplementedError(f'CLS is not provided for {self.name} tokenizer')

    @property
    def sep(self):
        raise NotImplementedError(f'SEP is not provided for {self.name} tokenizer')

    @property
    def pad(self):
        raise NotImplementedError(f'PAD is not provided for {self.name} tokenizer')

    @property
    def eod(self):
        raise NotImplementedError(f'EOD is not provided for {self.name} tokenizer')

    @property
    def mask(self):
        raise NotImplementedError(f'MASK is not provided for {self.name} tokenizer')
61
+
62
+
63
class _SentencePieceTokenizer(AbstractTokenizer):
    """SentencePieceTokenizer-Megatron wrapper"""

    def __init__(self, model_file, vocab_extra_ids=0):
        """Load a SentencePiece model and register Megatron special tokens.

        Args:
            model_file: Path to the trained SentencePiece ``.model`` file.
            vocab_extra_ids: Number of ``<extra_id_N>`` sentinel tokens to
                append (T5-style placeholders).
        """
        super().__init__('SentencePieceTokenizer')

        import sentencepiece
        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
        self._initalize(vocab_extra_ids)

    def _populate_vocab(self):
        # Build the forward (piece -> id) and inverse (id -> piece) maps from
        # the underlying SentencePiece model.
        pieces = [self.tokenizer.id_to_piece(i) for i in range(len(self.tokenizer))]
        self._vocab = {piece: i for i, piece in enumerate(pieces)}
        self._inv_vocab = dict(enumerate(pieces))

    def _initalize(self, vocab_extra_ids):
        # NOTE: method name kept as-is ("_initalize") because the subclass
        # overrides it under the same (misspelled) name.
        self._populate_vocab()
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        def _add_special_token(t):
            # Extend the vocab only when the piece is new; always record it in
            # the special-token maps.
            if t not in self._vocab:
                fresh_id = len(self._vocab)
                self._vocab[t] = fresh_id
                self._inv_vocab[fresh_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t

        for marker in ('<CLS>', '<SEP>', '<EOD>', '<MASK>'):
            _add_special_token(marker)
        self._cls_id = self._vocab['<CLS>']
        self._sep_id = self._vocab['<SEP>']
        self._eod_id = self._vocab['<EOD>']
        self._mask_id = self._vocab['<MASK>']

        def _piece_for(token_id, fallback):
            # The SentencePiece model may not define pad/bos/eos (id -1);
            # fall back to a synthetic piece in that case.
            try:
                return self.tokenizer.id_to_piece(token_id)
            except IndexError:
                return fallback

        pad_token = _piece_for(self.tokenizer.pad_id(), '<PAD>')
        _add_special_token(pad_token)
        self._pad_id = self._vocab[pad_token]

        bos_token = _piece_for(self.tokenizer.bos_id(), '<BOS>')
        _add_special_token(bos_token)
        self._bos_id = self._vocab[bos_token]

        eos_token = _piece_for(self.tokenizer.eos_id(), '<EOS>')
        _add_special_token(eos_token)
        self._eos_id = self._vocab[eos_token]

        # Optional T5-style sentinel tokens.
        for i in range(vocab_extra_ids):
            sentinel = "<extra_id_{}>".format(i)
            _add_special_token(sentinel)
            self._t5_tokens.append(sentinel)

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def decoder(self):
        return self._inv_vocab

    @property
    def encoder(self):
        return self._vocab

    # Adapted from:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
    def tokenize(self, text):
        """Encode text to ids, mapping special tokens to their reserved ids."""
        ids = []
        pos = 0

        while True:
            # Distance (from pos) to the first occurrence of each special token.
            hits = {}
            for tok in self._special_tokens:
                found = text.find(tok, pos)
                if found != -1:
                    hits[tok] = found - pos
            if not hits:
                break

            # Handle the earliest special token; encode the plain text before it.
            nearest = min(hits, key=hits.get)
            cut = pos + hits[nearest]

            ids.extend(self.tokenizer.encode_as_ids(text[pos:cut]))
            ids.append(self._special_tokens[nearest])
            pos = cut + len(nearest)

        ids.extend(self.tokenizer.encode_as_ids(text[pos:]))
        return ids

    # Adapted from:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
    def detokenize(self, ids):
        """Decode ids back to text, spelling out special tokens literally."""
        parts = []
        start = 0

        for pos, token_id in enumerate(ids):
            if token_id in self._inv_special_tokens:
                parts.append(self.tokenizer.decode_ids(ids[start:pos]) + " ")
                parts.append(self._inv_special_tokens[token_id] + " ")
                start = pos + 1

        parts.append(self.tokenizer.decode_ids(ids[start:]))
        return "".join(parts)

    @property
    def cls(self):
        return self._cls_id

    @property
    def sep(self):
        return self._sep_id

    @property
    def pad(self):
        return self._pad_id

    @property
    def bos_token_id(self):
        return self._bos_id

    @property
    def bos(self):
        return self._bos_id

    @property
    def eod(self):
        return self._eod_id

    @property
    def eos_token_id(self):
        return self._eos_id

    @property
    def eos(self):
        return self._eos_id

    @property
    def mask(self):
        return self._mask_id

    @property
    def additional_special_tokens_ids(self):
        return [self.vocab[k] for k in self._t5_tokens]
236
+
237
class _MMSentencePieceTokenizer(_SentencePieceTokenizer):
    """SentencePieceTokenizer-Megatron wrapper"""

    def __init__(self, model_file, vocab_extra_ids=0):
        super().__init__(model_file, vocab_extra_ids)

    def _initalize(self, vocab_extra_ids):
        # Same setup as the parent, plus the multimodal marker tokens
        # (SOA/EOA etc. — presumably start/end delimiters for the audio
        # token streams; verify against the checkpoint's usage).
        self._populate_vocab()
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        def _add_special_token(t):
            # Extend the vocab only when the piece is new; always record it in
            # the special-token maps.
            if t not in self._vocab:
                fresh_id = len(self._vocab)
                self._vocab[t] = fresh_id
                self._inv_vocab[fresh_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t

        # Registration order matters: new ids are assigned sequentially, so
        # this table must stay in the exact order the checkpoint expects.
        for marker, attr in (
            ('<CLS>', '_cls_id'),
            ('<SEP>', '_sep_id'),
            ('<EOD>', '_eod_id'),
            ('<MASK>', '_mask_id'),
            ('<SOA>', '_soa_id'),
            ('<EOA>', '_eoa_id'),
            ('<SOV>', '_sov_id'),
            ('<EOV>', '_eov_id'),
            ('<SOI>', '_soi_id'),
            ('<EOI>', '_eoi_id'),
            ('<s_local>', '_s_local_id'),
            ('<e_local>', '_e_local_id'),
            ('<s_global>', '_s_global_id'),
            ('<e_global>', '_e_global_id'),
            ('<stage_1>', '_stage_1_id'),
            ('<stage_2>', '_stage_2_id'),
        ):
            _add_special_token(marker)
            setattr(self, attr, self._vocab[marker])

        def _piece_for(token_id, fallback):
            # pad/bos/eos may be undefined in the model; synthesize a piece.
            try:
                return self.tokenizer.id_to_piece(token_id)
            except IndexError:
                return fallback

        pad_token = _piece_for(self.tokenizer.pad_id(), '<PAD>')
        _add_special_token(pad_token)
        self._pad_id = self._vocab[pad_token]

        bos_token = _piece_for(self.tokenizer.bos_id(), '<BOS>')
        _add_special_token(bos_token)
        self._bos_id = self._vocab[bos_token]

        eos_token = _piece_for(self.tokenizer.eos_id(), '<EOS>')
        _add_special_token(eos_token)
        self._eos_id = self._vocab[eos_token]

        # Optional T5-style sentinel tokens.
        for i in range(vocab_extra_ids):
            sentinel = "<extra_id_{}>".format(i)
            _add_special_token(sentinel)
            self._t5_tokens.append(sentinel)

    # --- accessors for the multimodal marker ids -------------------------

    @property
    def soa(self):
        return self._soa_id

    @property
    def eoa(self):
        return self._eoa_id

    @property
    def sov(self):
        return self._sov_id

    @property
    def eov(self):
        return self._eov_id

    @property
    def soi(self):
        return self._soi_id

    @property
    def eoi(self):
        return self._eoi_id

    @property
    def s_local(self):
        return self._s_local_id

    @property
    def e_local(self):
        return self._e_local_id

    @property
    def s_global(self):
        return self._s_global_id

    @property
    def e_global(self):
        return self._e_global_id

    @property
    def stage_1(self):
        return self._stage_1_id

    @property
    def stage_2(self):
        return self._stage_2_id
data_pipeline/lyrics_gene/__pycache__/filter_all_cn.cpython-311.pyc ADDED
Binary file (12.1 kB). View file
 
data_pipeline/lyrics_gene/__pycache__/filter_all_en.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
data_pipeline/lyrics_gene/__pycache__/gen_lyrics_cn.cpython-311.pyc ADDED
Binary file (31 kB). View file
 
data_pipeline/lyrics_gene/filter_all_cn.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ import torch
5
+ from sentence_transformers import SentenceTransformer, util
6
+ from tqdm import tqdm
7
+
8
+ # os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
9
+ # os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
10
+
11
+ # Use HuggingFace mirror site
12
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
13
+ # Set HuggingFace cache directory so SentenceTransformer can recognize downloaded models
14
+ os.environ["HF_HOME"] = os.path.expanduser("~/.cache/huggingface")
15
+
16
+
17
def clean_lyrics(lyrics):
    """
    Strip structural annotations from raw lyrics, leaving plain text.

    Removes every bracketed tag (section markers such as [Verse 1] as well as
    timestamps such as [00:07.00]), flattens newlines into spaces, and
    collapses runs of whitespace into single spaces.

    Args:
        lyrics: Raw lyrics text possibly containing [tags] and newlines.

    Returns:
        A single-line, whitespace-normalized lyrics string.
    """
    # Drop every bracketed annotation (non-greedy, so adjacent tags are
    # removed independently instead of swallowing the text between them).
    text = re.sub(r'\[.*?\]', '', lyrics)

    # Newlines become spaces, then any run of whitespace collapses to one
    # space, and the edges are trimmed.
    text = re.sub(r'\s+', ' ', text.replace('\n', ' '))
    return text.strip()
41
+
42
+
43
def load_music_data(input_file, max_count=None):
    """
    Read music records from a jsonl file.

    Keeps only lines that parse as JSON objects containing both a
    'description' and a 'lyrics' field; malformed lines are silently skipped.

    Args:
        input_file: Path to the input jsonl file.
        max_count: Optional cap on the number of records to load; None loads all.

    Returns:
        List of music record dicts.
    """
    records = []
    print(f"Loading music data: {input_file}")
    if max_count:
        print(f"Limiting to first {max_count} songs")

    with open(input_file, 'r', encoding='utf-8') as fh:
        for raw_line in tqdm(fh, desc="Loading data"):
            try:
                item = json.loads(raw_line.strip())
                # Require both fields before accepting the record.
                if 'description' in item and 'lyrics' in item:
                    records.append(item)
                    # Honor the optional cap on loaded records.
                    if max_count and len(records) >= max_count:
                        break
            except json.JSONDecodeError:
                continue
    print(f"Successfully loaded {len(records)} songs")
    return records
73
+
74
+
75
def deduplicate_music(music_list, texts, model, threshold=0.90, output_file=None, save_interval=10000, matrix_save_dir=None):
    """
    Remove near-duplicate songs based on pairwise text similarity.

    Encodes every text with the given SentenceTransformer model, computes the
    full cosine-similarity matrix, and greedily keeps the first song of every
    similar group: any later song whose similarity to a kept song exceeds the
    threshold is dropped.

    Args:
        music_list: List of music data dicts.
        texts: Texts used for comparison, aligned index-for-index with music_list.
        model: SentenceTransformer model used for embedding.
        threshold: Cosine-similarity threshold above which songs count as duplicates.
        output_file: Optional output path; when given, kept songs are written
            incrementally so progress survives interruption.
        save_interval: Flush to disk every time this many new songs are kept.
        matrix_save_dir: Optional directory in which to persist the embeddings
            and similarity matrix.

    Returns:
        Deduplicated list of music data dicts.
    """
    print(f"Computing embeddings for {len(texts)} texts...")
    embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)

    print("Computing similarity matrix...")
    sim_matrix = util.pytorch_cos_sim(embeddings, embeddings)

    # Optionally persist both artifacts for later inspection / reuse.
    if matrix_save_dir:
        os.makedirs(matrix_save_dir, exist_ok=True)
        emb_path = os.path.join(matrix_save_dir, 'embeddings.pt')
        sim_path = os.path.join(matrix_save_dir, 'cos_scores.pt')
        print(f"Saving embeddings to: {emb_path}")
        torch.save(embeddings.cpu(), emb_path)
        print(f"Saving similarity matrix to: {sim_path}")
        torch.save(sim_matrix.cpu(), sim_path)
        print("Matrix saving complete!")

    print(f"Deduplicating (threshold: {threshold})...")
    keep_idx = []
    removed = set()

    # Incremental saving: open the output file up front when a path is given.
    out_fh = None
    if output_file:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        out_fh = open(output_file, 'w', encoding='utf-8')

    saved_count = 0

    for i in tqdm(range(len(music_list)), desc="Deduplication progress"):
        if i in removed:
            continue
        keep_idx.append(i)

        # Periodically flush newly kept songs so partial progress is durable.
        if out_fh and len(keep_idx) - saved_count >= save_interval:
            for pos in range(saved_count, len(keep_idx)):
                out_fh.write(json.dumps(music_list[keep_idx[pos]], ensure_ascii=False) + '\n')
            out_fh.flush()  # make sure the batch actually reaches disk
            saved_count = len(keep_idx)
            print(f"Saved {saved_count} valid songs to file")

        # Mark every later song that is too similar to this kept one.
        for j in range(i + 1, len(music_list)):
            if sim_matrix[i][j] > threshold:
                removed.add(j)

    # Write out whatever has not been flushed yet and close the file.
    if out_fh:
        for pos in range(saved_count, len(keep_idx)):
            out_fh.write(json.dumps(music_list[keep_idx[pos]], ensure_ascii=False) + '\n')
        out_fh.close()
        print(f"Saved all {len(keep_idx)} valid songs to file")

    survivors = [music_list[i] for i in keep_idx]
    print(f"Deduplication complete: {len(music_list)} -> {len(survivors)} (removed {len(removed)} songs)")

    return survivors
151
+
152
+
153
def dedup_by_description_and_lyrics(input_file, output_file, threshold=0.95, max_count=None, device='cuda:1', save_interval=10000, matrix_save_dir=None):
    """
    Method 2: Deduplicate based on description + lyrics.

    Loads songs from `input_file`, embeds "description [SEP] cleaned lyrics"
    with a Chinese sentence-embedding model, and removes near-duplicates.

    Args:
        input_file: Path to input jsonl file.
        output_file: Path to output jsonl file (enables incremental saving).
        threshold: Similarity threshold.
        max_count: Maximum number to read, None means read all.
        device: Device to use, default cuda:1 (GPU1).
        save_interval: Save every N valid songs processed, default 10000.
        matrix_save_dir: Directory to save matrices, if provided saves
            embeddings and similarity matrix.

    Returns:
        Deduplicated list of music data dicts.
    """
    print("\n========== Method 2: Deduplicate based on description + lyrics ==========")
    print(f"Using device: {device}")

    # Load data
    music_list = load_music_data(input_file, max_count=max_count)

    # Build the comparison text: description + cleaned lyrics, joined by [SEP].
    combined_texts = []
    for music in music_list:
        description = music.get('description', '')
        cleaned_lyrics = clean_lyrics(music.get('lyrics', ''))
        combined_texts.append(f"{description} [SEP] {cleaned_lyrics}")

    # Prefer a locally cached snapshot of the model to avoid re-downloading.
    hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    model_cache_dir = os.path.join(hf_home, "hub", "models--shibing624--text2vec-bge-large-chinese", "snapshots")

    local_model_path = None
    if os.path.exists(model_cache_dir):
        snapshots = [d for d in os.listdir(model_cache_dir) if os.path.isdir(os.path.join(model_cache_dir, d))]
        if snapshots:
            # Use latest snapshot (usually unique)
            local_model_path = os.path.join(model_cache_dir, snapshots[0])
            if os.path.exists(os.path.join(local_model_path, "config.json")):
                print(f"Detected local model, using path: {local_model_path}")
                model = SentenceTransformer(local_model_path, device=device)
            else:
                local_model_path = None

    if local_model_path is None:
        print(f"Loading model to {device}...")
        model = SentenceTransformer('shibing624/text2vec-bge-large-chinese', device=device)

    # Deduplicate (writes incrementally when output_file is set).
    deduped_music_list = deduplicate_music(
        music_list,
        combined_texts,
        model,
        threshold,
        output_file=output_file,
        save_interval=save_interval,
        matrix_save_dir=matrix_save_dir
    )

    if output_file:
        # deduplicate_music already saved the results incrementally.
        print(f"✓ Save complete! Remaining {len(deduped_music_list)} songs after deduplication\n")
    else:
        # Bug fix: the old fallback attempted to write to `output_file` even
        # though this branch only runs when no path was supplied, so
        # os.path.dirname(output_file)/open(output_file) crashed before
        # writing anything. Just report the result instead.
        print(f"✓ Deduplication complete! Remaining {len(deduped_music_list)} songs (no output file given)\n")

    return deduped_music_list
229
+
230
+
231
if __name__ == '__main__':
    # I/O paths for this deduplication run.
    input_file = 'lrc_4w_single_pro_des.jsonl'
    output_file = 'filter_all_4w.jsonl'

    # Where the embeddings / similarity matrix get persisted.
    matrix_save_dir = 'generate_lrc'

    # None processes the whole dataset; set a small number when testing.
    max_count = None

    # Deduplicate based on description + lyrics.
    print("\nDeduplicating based on description + lyrics")
    dedup_by_description_and_lyrics(
        input_file,
        output_file,
        threshold=0.90,
        max_count=max_count,
        device='cuda:7',
        save_interval=10000,  # flush every 10000 kept songs
        matrix_save_dir=matrix_save_dir,  # persist the similarity matrix
    )
    print(f"\nComplete! Results saved to: {output_file}")
    print(f"Similarity matrix saved to: {matrix_save_dir}")
272
+
data_pipeline/lyrics_gene/filter_all_en.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ import torch
5
+ from sentence_transformers import SentenceTransformer, util
6
+ from tqdm import tqdm
7
+
8
# os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"

# Use the HuggingFace mirror site for model downloads.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Point HF_HOME at the default cache directory so SentenceTransformer can
# find previously downloaded models.
os.environ["HF_HOME"] = os.path.expanduser("~/.cache/huggingface")

# English sentence-embedding model used for similarity-based deduplication.
MODEL_EN = "sentence-transformers/all-MiniLM-L6-v2"
17
+
18
+
19
def clean_lyrics(lyrics):
    """Strip section/timestamp tags and newlines from lyrics, returning plain text.

    Removes every ``[...]`` tag (section markers such as ``[Verse 1]`` as well
    as LRC timestamps such as ``[00:07.00]``) and collapses all whitespace —
    including newlines — into single spaces.

    Args:
        lyrics: Raw lyrics text possibly containing tags and newlines.

    Returns:
        The cleaned, single-line lyric text.
    """
    # Drop every bracketed tag first.
    without_tags = re.sub(r'\[.*?\]', '', lyrics)
    # Splitting on arbitrary whitespace and re-joining collapses newlines and
    # runs of spaces in one pass and also trims both ends.
    return ' '.join(without_tags.split())
43
+
44
+
45
def load_music_data(input_file, max_count=None):
    """Load song records from a jsonl file.

    Only lines that parse as JSON objects containing both a ``description``
    and a ``lyrics`` field are kept; malformed lines are skipped silently.

    Args:
        input_file: Path to the input jsonl file.
        max_count: Optional cap on how many songs to load; None loads all.

    Returns:
        A list of song dicts.
    """
    print(f"Loading music data: {input_file}")
    if max_count:
        print(f"Limiting to first {max_count} songs")

    songs = []
    with open(input_file, 'r', encoding='utf-8') as fh:
        for raw_line in tqdm(fh, desc="Loading data"):
            try:
                record = json.loads(raw_line.strip())
            except json.JSONDecodeError:
                # Skip lines that are not valid JSON.
                continue
            if 'description' in record and 'lyrics' in record:
                songs.append(record)
                # Stop once the requested number of songs is collected.
                if max_count and len(songs) >= max_count:
                    break

    print(f"Successfully loaded {len(songs)} songs")
    return songs
75
+
76
+
77
def deduplicate_music(music_list, texts, model, threshold=0.90, output_file=None, save_interval=10000, matrix_save_dir=None):
    """Deduplicate music data based on pairwise text similarity.

    Embeds every text with ``model``, computes the full cosine-similarity
    matrix, then greedily keeps the first occurrence of each near-duplicate
    group: song j is dropped when some earlier kept song i has
    cosine similarity > ``threshold``.

    Args:
        music_list: List of music data dicts (parallel to ``texts``).
        texts: Texts to embed and compare, one per song.
        model: SentenceTransformer model used for encoding.
        threshold: Similarity threshold above which two songs count as duplicates.
        output_file: Optional jsonl path; when given, kept songs are written
            incrementally while deduplicating.
        save_interval: Flush kept songs to disk every N kept songs.
        matrix_save_dir: Optional directory in which to save the embeddings
            and the similarity matrix as torch tensors.

    Returns:
        The deduplicated list of music dicts.
    """
    print(f"Computing embeddings for {len(texts)} texts...")
    embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)

    print("Computing similarity matrix...")
    cos_scores = util.pytorch_cos_sim(embeddings, embeddings)

    # Persist embeddings and similarity matrix for later inspection / reuse.
    if matrix_save_dir:
        os.makedirs(matrix_save_dir, exist_ok=True)
        embeddings_path = os.path.join(matrix_save_dir, 'embeddings.pt')
        cos_scores_path = os.path.join(matrix_save_dir, 'cos_scores.pt')
        print(f"Saving embeddings to: {embeddings_path}")
        torch.save(embeddings.cpu(), embeddings_path)
        print(f"Saving similarity matrix to: {cos_scores_path}")
        torch.save(cos_scores.cpu(), cos_scores_path)
        print("Matrix saving complete!")

    print(f"Deduplicating (threshold: {threshold})...")
    keep_idx = []
    removed = set()

    # If an output file was provided, open it for incremental saving.
    f = None
    if output_file:
        # BUGFIX: os.path.dirname() returns '' for a bare filename and
        # os.makedirs('') raises FileNotFoundError — only create the
        # directory when there actually is one.
        out_dir = os.path.dirname(output_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        f = open(output_file, 'w', encoding='utf-8')

    saved_count = 0

    for i in tqdm(range(len(music_list)), desc="Deduplication progress"):
        if i in removed:
            continue
        keep_idx.append(i)

        # Incrementally flush kept songs every save_interval entries.
        if f and len(keep_idx) - saved_count >= save_interval:
            for idx in range(saved_count, len(keep_idx)):
                music = music_list[keep_idx[idx]]
                f.write(json.dumps(music, ensure_ascii=False) + '\n')
            f.flush()  # Ensure write to disk
            saved_count = len(keep_idx)
            print(f"Saved {saved_count} valid songs to file")

        # Mark every later song that is too similar to song i as removed.
        for j in range(i + 1, len(music_list)):
            if cos_scores[i][j] > threshold:
                removed.add(j)

    # Save remaining kept songs and close the file.
    if f:
        for idx in range(saved_count, len(keep_idx)):
            music = music_list[keep_idx[idx]]
            f.write(json.dumps(music, ensure_ascii=False) + '\n')
        f.close()
        print(f"Saved all {len(keep_idx)} valid songs to file")

    deduped_music_list = [music_list[i] for i in keep_idx]
    print(f"Deduplication complete: {len(music_list)} -> {len(deduped_music_list)} (removed {len(removed)} songs)")

    return deduped_music_list
153
+
154
+
155
def dedup_by_description_and_lyrics(input_file, output_file, threshold=0.95, max_count=None, device='cuda:1', save_interval=10000, matrix_save_dir=None):
    """Method 2: Deduplicate songs by description + lyrics similarity.

    Loads the songs, builds one combined comparison text per song
    (``"{description} [SEP] {cleaned lyrics}"``), embeds them with the
    English sentence-transformer model and drops near-duplicates.

    Args:
        input_file: Path to input jsonl file.
        output_file: Path to output jsonl file (written incrementally inside
            ``deduplicate_music``).
        threshold: Similarity threshold above which songs are duplicates.
        max_count: Maximum number of songs to read; None reads all.
        device: Torch device for the embedding model, e.g. ``'cuda:1'``.
        save_interval: Flush kept songs to disk every N kept songs.
        matrix_save_dir: Optional directory for saving embeddings and the
            similarity matrix.

    Returns:
        The deduplicated list of music dicts.
    """
    print("\n========== Method 2: Deduplicate based on description + lyrics ==========")
    print(f"Using device: {device}")

    # Load data
    music_list = load_music_data(input_file, max_count=max_count)

    # Build one comparison text per song: description + cleaned lyrics,
    # joined with an explicit separator token.
    combined_texts = []
    for music in music_list:
        description = music.get('description', '')
        cleaned_lyrics = clean_lyrics(music.get('lyrics', ''))
        combined_texts.append(f"{description} [SEP] {cleaned_lyrics}")

    # Load English embedding model
    print(f"Loading English model {MODEL_EN} to {device}...")
    model = SentenceTransformer(MODEL_EN, device=device)

    # Deduplicate.  When output_file is provided the results are written
    # incrementally inside deduplicate_music, so no extra save pass is
    # needed here.  (BUGFIX: the previous fallback save branch only ran when
    # output_file was falsy and then tried to write to that same falsy path
    # — it could never succeed and has been removed.)
    deduped_music_list = deduplicate_music(
        music_list,
        combined_texts,
        model,
        threshold,
        output_file=output_file,
        save_interval=save_interval,
        matrix_save_dir=matrix_save_dir
    )

    if output_file:
        print(f"✓ Save complete! Remaining {len(deduped_music_list)} songs after deduplication\n")

    return deduped_music_list
213
+
214
+
215
if __name__ == '__main__':
    # Input file path (English-language songs, jsonl with description/lyrics)
    input_file = 'en_lrc_4w_single_pro_des.jsonl'

    # Output file path for the deduplicated songs
    output_file = 'filter_en_single_4w(0.9).jsonl'

    # Directory where embeddings / similarity matrix are saved
    matrix_save_dir = 'en_matrix'

    # Maximum number of songs to read; None means read everything.
    # (Set to a small number, e.g. 5, when testing.)
    max_count = None

    # Deduplicate based on description + lyrics
    print("\nDeduplicating based on description + lyrics")
    dedup_by_description_and_lyrics(
        input_file,
        output_file,
        threshold=0.90,
        max_count=max_count,
        device='cuda:7',
        save_interval=10000,  # Save every 10000 valid songs
        matrix_save_dir=matrix_save_dir  # Save similarity matrix
    )
    print(f"\nComplete! Results saved to: {output_file}")
    print(f"Similarity matrix saved to: {matrix_save_dir}")
    # Manual check of the lyrics-cleaning output (uncomment to run):
    # print("\n========== Test Lyrics Cleaning Effect ==========")
    # music_list = load_music_data(input_file, max_count=max_count)

    # print("\n" + "="*80)
    # for i, music in enumerate(music_list, 1):
    #     print(f"\n[Song {i}]")
    #     print(f"Description: {music.get('description', '')}")
    #     print("\n--- Original Lyrics ---")
    #     original_lyrics = music.get('lyrics', '')
    #     print(original_lyrics[:500] + "..." if len(original_lyrics) > 500 else original_lyrics)
    #     print("\n--- Cleaned Lyrics ---")
    #     cleaned_lyrics = clean_lyrics(original_lyrics)
    #     print(cleaned_lyrics)
    #     print("\n" + "-"*80)
+
data_pipeline/lyrics_gene/gen_lyrics_cn.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import random
5
+ import re
6
+ from openai import OpenAI
7
+ from tqdm import tqdm
8
+ import threading
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+
11
# Credentials are taken from the environment when present; the assignments
# below only install fallback defaults (empty key / public endpoint).
# Set before running the script:
#   export OPENAI_API_KEY="your-api-key"
#   export OPENAI_BASE_URL="https://api.openai.com/v1"   # or a custom gateway
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = ""  # Replace with your API key or set via environment variable
if not os.environ.get("OPENAI_BASE_URL"):
    os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1"  # Replace with API URL or set via environment variable

# Initialize the shared client (reads OPENAI_API_KEY / OPENAI_BASE_URL
# from the environment variables set above).
client = OpenAI()
22
+
23
+
24
+ def _extract_lyrics_timestamps(lyrics_text):
25
+ """
26
+ Extract timestamps from lyrics and convert to seconds
27
+ Args:
28
+ lyrics_text: Lyrics string
29
+ Returns:
30
+ List[float]: Timestamps in order (seconds)
31
+ """
32
+ if not isinstance(lyrics_text, str):
33
+ return []
34
+ pattern = re.compile(r'\[(\d{2}):(\d{2})(?:\.(\d{2}))?\]')
35
+ timestamps = []
36
+ for match in pattern.finditer(lyrics_text):
37
+ minutes = int(match.group(1))
38
+ seconds = int(match.group(2))
39
+ fraction = match.group(3)
40
+ total_seconds = minutes * 60 + seconds
41
+ if fraction is not None:
42
+ divisor = 100 if len(fraction) == 2 else 10 ** len(fraction)
43
+ total_seconds += int(fraction) / divisor
44
+ timestamps.append(total_seconds)
45
+ return timestamps
46
+
47
+
48
def _validate_timestamps(lyrics_text, min_last_timestamp=170, max_interval=35):
    """Check whether the lyrics' timestamps satisfy the length requirements.

    Args:
        lyrics_text: Lyrics string containing LRC timestamps.
        min_last_timestamp: Minimum allowed value of the final timestamp (seconds).
        max_interval: Maximum allowed gap between the last two timestamps (seconds).

    Returns:
        bool: True when the lyrics pass validation.
    """
    stamps = _extract_lyrics_timestamps(lyrics_text)

    # Need at least two timestamps to judge the ending of the song.
    if len(stamps) < 2:
        print("Validation failed: Timestamp count less than 2")
        return False

    second_last, last = stamps[-2], stamps[-1]

    # The song must be long enough overall.
    if last < min_last_timestamp:
        print(f"Validation failed: Last timestamp {last:.2f}s is less than {min_last_timestamp}s")
        return False

    # The final line must not be an implausibly long jump from the previous one.
    gap = last - second_last
    if gap > max_interval:
        print(f"Validation failed: Interval between last two timestamps {gap:.2f}s is greater than {max_interval}s")
        return False

    return True
71
+
72
+
73
def chat_gpt(text, model='gpt-4o-mini'):
    """Send a single-turn chat request, retrying forever until content arrives.

    Args:
        text: User prompt to send.
        model: Chat model name.

    Returns:
        str: The assistant's reply text, stripped of surrounding whitespace.
    """
    while True:
        try:
            # Call OpenAI chat completions API
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": text}
                ]
            )
            # Only return once the model actually produced content.
            if getattr(completion.choices[0].message, 'content', None):
                return completion.choices[0].message.content.strip()
            # BUGFIX: previously the empty-content path retried immediately,
            # hammering the API in a tight loop; back off like the error path.
            print('error_wait_2s')
            time.sleep(2)
        except Exception as e:
            print(f"Error: {e}")
            time.sleep(2)
93
+
94
+
95
def chat_gpt_call(text, model='gpt-4o-mini'):
    """Single-shot variant of chat_gpt: one API call, no retry loop.

    Args:
        text: User prompt to send.
        model: Chat model name.

    Returns:
        str | None: The assistant's reply text, stripped; implicitly returns
        None when the response carried no content (only a message is printed).
        Exceptions from the API call propagate to the caller.
    """
    # Call OpenAI chat completions API
    completion = client.chat.completions.create(
        model=model,  # Use GPT-4o-mini model
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text}
        ]
    )
    # Get reply content
    if getattr(completion.choices[0].message, 'content', None):
        content = completion.choices[0].message.content.strip()
        return content
    else:
        # NOTE(review): falls through returning None here — callers must
        # handle a None result.
        print('error_wait_2s')
110
+
111
+
112
def generate_music_descriptions(all_music_data, index_pool, output_file, file_lock, sample_size=20, model='gpt-4o-mini', max_retries=0):
    """Sample reference songs and ask GPT for 2 new description+lyrics pairs.

    Draws ``sample_size`` reference songs from ``index_pool``, builds a
    Chinese prompt that requests 2 new songs whose final LRC timestamp falls
    inside a randomly chosen duration bucket, validates the timestamps of the
    returned songs, and appends accepted songs to ``output_file``.

    Args:
        all_music_data: Full list of source song dicts.
        index_pool: Thread-safe IndexPool supplying sample indices.
        output_file: Output jsonl path (appended to under ``file_lock``).
        file_lock: threading.Lock guarding writes to ``output_file``.
        sample_size: Number of reference songs to sample per call.
        model: Model name to use.
        max_retries: Number of extra attempts after the first one.

    Returns:
        (used_indices, success_count): List of consumed indices and the
        number of songs successfully generated and saved.
    """
    # Duration buckets: (start "m.ss", end "m.ss", minimum lyric line count).
    duration_ranges = [
        ("3.00", "3.15", 60), ("3.15", "3.30", 70), ("3.30", "3.45", 80),("3.45", "4.00", 90),
        ("4.00", "4.15", 100), ("4.15", "4.30", 100), ("4.30", "4.45", 100)
    ]
    selected_range = random.choice(duration_ranges)
    require_length = selected_range[2]

    # Convert bucket bounds into LRC timestamps (strictly the left and right
    # ends of the tuple), e.g. "3.00" -> "[3:00.00]".
    start_timestamp = f"[{selected_range[0].replace('.', ':')}.00]"
    end_timestamp = f"[{selected_range[1].replace('.', ':')}.00]"

    # Human-readable duration strings for the prompt ("3分00秒" style).
    start_duration = f"{selected_range[0].replace('.', '分')}秒"
    end_duration = f"{selected_range[1].replace('.', '分')}秒"

    # Build two random example end-timestamps inside the bucket.
    # Parse the "m.ss" strings into minutes and seconds first.
    start_parts = selected_range[0].split('.')
    end_parts = selected_range[1].split('.')

    start_minutes = int(start_parts[0])
    start_seconds = int(start_parts[1])
    end_minutes = int(end_parts[0])
    end_seconds = int(end_parts[1])

    # Convert to total seconds
    start_total_seconds = start_minutes * 60 + start_seconds
    end_total_seconds = end_minutes * 60 + end_seconds

    # Randomly pick two example end times within the bucket.
    example1_seconds = random.randint(start_total_seconds, end_total_seconds)
    example2_seconds = random.randint(start_total_seconds, end_total_seconds)

    example1_minutes = example1_seconds // 60
    example1_secs = example1_seconds % 60
    example2_minutes = example2_seconds // 60
    example2_secs = example2_seconds % 60

    example1_timestamp = f"[{example1_minutes:02d}:{example1_secs:02d}.00]"
    example2_timestamp = f"[{example2_minutes:02d}:{example2_secs:02d}.00]"

    # Get sample indices from the index pool (thread-safe).
    selected_indices = index_pool.get_indices(sample_size)
    if not selected_indices:
        return [], 0

    sample_data = [all_music_data[i] for i in selected_indices]

    # Collect the unique styles of the sampled songs.
    styles = []
    for data in sample_data:
        style = data.get('style', '')
        if style and style not in styles:
            styles.append(style)

    styles_text = "、".join(styles)

    # Build example text from the sampled songs (style excluded).
    examples = []
    for i, data in enumerate(sample_data, 1):
        lyrics_text = " ".join(data.get('lyrics', [])) if isinstance(data.get('lyrics'), list) else data.get('lyrics', '')
        description = data.get('description', '')
        examples.append(f"示例{i}:\ndescription: {description}\nlyrics: {lyrics_text}")

    examples_text = "\n\n".join(examples)

    # NOTE(review): styles_text and examples_text are built above but never
    # interpolated into the prompt below — confirm whether they were meant to
    # be included in the few-shot prompt.
    prompt = f"""生成2首完整的歌曲,每首歌必须满足以下硬性指标:
- 严禁生成小于{require_length}行的歌词!
- 每首歌的歌词行数必须严格大于{require_length}行,这是硬性要求!
- 最后一句时间戳必须在{start_timestamp}到{end_timestamp}之间
- 两首歌的时长、行数必须有差异,严禁最后的时间戳均相同
- 相邻歌词行的时间戳间隔不得超过10秒!必须保证时间戳连续自然递进
- 严禁出现如"[03:25.00]在心中[04:25.00]最后一行歌词"的生硬间隔,严禁超过10s的间隔
如果生成的歌曲不满足以上任意一项,则视为不合格,请重新生成。
请生成2首新的、具有多样性的音乐描述和LRC格式歌词,语言为中文。


创作要求:
1.风格与流派需要确保多样性。
2.Description标签化要求(必须严格遵守):
description字段必须使用结构化的标签格式,包括以下标签,用逗号分隔:
- 音乐风格标签
- 音乐流派标签
- 乐器标签
- 情感基调标签
- 氛围标签
- 演唱方式和人声标签,仅限男声或女生二选一,单人独唱
注意:每个标签简洁明了,多个同类标签可用斜杠分隔(如"钢琴/小提琴")
3.歌词创造力:lyrics应该具有深度和艺术性:
- 主题可以涉及爱情、人生、社会、自然、哲思、梦想、回忆等各个方面
- 运用丰富的文学手法:比喻、意象、对比、排比等
- 情感真挚,注重韵律和节奏感
- 可以是叙事性、抒情性或意识流风格
4.歌词结构和长度要求(必须严格遵守):
- lyrics必须按照以下结构组织,并使用段落标签标注每个部分
- 结构顺序必须严格遵循该顺序,共8个段落标签:[Verse 1]主歌1 → [Pre-Chorus]预副歌 → [Chorus]副歌 → [Verse 2]主歌2 → [Pre-Chorus]预副歌 → [Chorus]副歌 → [Bridge]桥段 → [Chorus (Outro)]副歌(结尾)
- 一首歌词的段落标签只有8个,即[Verse 1]和[Verse 2]只出现一次,[Pre-Chorus]和[Chorus]各出现两次,[Bridge]和[Chorus (Outro)]各出现一次,禁止额外添加或重复更多的段落标签
- 每个段落标签(如[Verse 1]、[Chorus]等)必须独占一行,后面紧跟该段落的LRC格式歌词
- 段落之间用空行分隔
- **总行数要求**:整首歌必须包含至少{require_length}行带时间戳的歌词(不包括段落标签行和空行)
5.LRC格式强制规则(必须��格遵守):
- 每行歌词格式必须为 `[mm:ss.xx]歌词内容`,时间戳与歌词间无空格,歌词内容需完整连贯
- **每一行只能包含一小句歌词**,遇到逗号、句号等标点符号时必须换行。
- **严禁将多句歌词合并在同一行**
- 时间戳需自然分布,**第一句歌词起始时间不得为 [00:00.00]**,需考虑前奏空白(建议从[00:05.00]到[00:15.00]之间开始)
- 时间戳间隔要求多样性:每首歌内部的时间戳间隔必须多样化,多采用小数点数间隔,严禁使用固定间隔:
* 同一首歌内必须包含多种不同的间隔,不要所有句子都使用相同间隔(如不要全部都是4秒间隔)
* 根据歌词内容的情感强度和音乐节拍来动态调整间隔
* 相邻歌词行的间隔应该有所变化,体现音乐的节奏起伏
- 时间戳分配应根据歌曲的风格、情感、节奏来合理推测,而非机械地按照歌词长度分配
- 每行歌词长度应自然变化,切勿长度一致
- **歌曲总时长必须达到{start_duration}到{end_duration}(即最后一句时间戳必须在{start_timestamp}到{end_timestamp}之间)这是硬性要求!**
6.歌词长度要求:lyrics字段的歌词行数必须大于{require_length}行,若生成长度过短请重新生成。
7.独特性和原创性:每首作品都应该是独一无二的,避免简单重复示例的内容。
8.格式要求:
- 直接返回JSON数组格式,包含2个歌曲对象,每个对象只有description和lyrics两个字段
- description字段:必须是标签格式,不是叙述性文本
- lyrics字段:带段落标签的LRC格式字符串
- 严禁在JSON中插入任何额外的符号、标记、注释或说明文字

LRC格式示例(带段落标签):
[Verse 1]
[00:08.00]第一句歌词
[00:12.50]第二句歌词
[00:17.20]第三句歌词

[Pre-Chorus]
[00:22.00]预副歌歌词
[00:26.50]预副歌歌词

[Chorus]
[00:31.00]副歌歌词
[00:35.50]副歌歌词

负面示例(禁止出现):
- 错误:[01:30.00](钢琴间奏) - 禁止在时间戳后加括号注释
- 错误:[00:00.00]开始的歌词 - 第一句不能从00:00.00开始
- 错误: [00:05.00]在那片熟悉的田野,阳光洒满金色的麦穗 - 严禁多句歌词放在同一行

现在,请充分发挥你的创造力,生成2首全新的、完整的音乐描述和LRC格式歌词作品。
特别提醒:每首歌必须是完整歌曲,不要缩写或省略!必须包含完整的8个段落(Verse 1, Pre-Chorus, Chorus, Verse 2, Pre-Chorus, Chorus, Bridge, Chorus Outro),严格确保大于{require_length}行歌词。

直接返回JSON数组格式:
[
{{"description": "...", "lyrics": "..."}},
{{"description": "...", "lyrics": "..."}}
]"""
    # Try to generate, with a retry mechanism.
    for attempt in range(max_retries + 1):
        try:
            # Call OpenAI API (a single choice; the prompt asks for 2 songs).
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": f"You are a creative music lyricist and composer. Please generate diverse and creative music tag-based descriptions and LRC format lyrics with song structure tags. CRITICAL REQUIREMENTS: 1) Description must be structured tags separated by commas, NOT narrative text. 2) Return ONLY pure, valid JSON format without any extra symbols, markers, or comments. 3) Each song must include structure tags like [Verse 1], [Chorus], [Bridge], etc., followed by LRC format lyrics [mm:ss.xx]lyric_content. 4) MANDATORY: Each song must have MORE than {require_length} lines of lyrics with timestamps. "},
                    {"role": "user", "content": prompt}
                ],
                n=1,
                temperature=1.0,
            )
            # Extract and validate all responses.
            results = []
            filtered_count = 0
            last_content = None

            for i, choice in enumerate(completion.choices, 1):
                try:
                    content = choice.message.content.strip()
                    last_content = content
                    print(f"\n=== GPT Response {i} ===")
                    print(content)
                    print("=" * 50)
                    # Strip an optional markdown code fence around the JSON.
                    if "```json" in content:
                        content = content.split("```json")[1].split("```")[0].strip()
                    elif "```" in content:
                        content = content.split("```")[1].split("```")[0].strip()

                    # Clean trailing commas before a closing brace/bracket
                    # (a common model output glitch that breaks json.loads).
                    content = re.sub(r',(\s*[}\]])', r'\1', content)

                    # Parse JSON array
                    result_array = json.loads(content)

                    # Expected shape: a list of {description, lyrics} objects.
                    if isinstance(result_array, list):
                        # Validate each object in the array.
                        for song in result_array:
                            if isinstance(song, dict) and 'description' in song and 'lyrics' in song:
                                if _validate_timestamps(song.get('lyrics', '')):
                                    results.append(song)
                                else:
                                    filtered_count += 1
                    # Also accept a bare single object (old-format compatibility).
                    elif isinstance(result_array, dict) and 'description' in result_array and 'lyrics' in result_array:
                        if _validate_timestamps(result_array.get('lyrics', '')):
                            results.append(result_array)
                        else:
                            filtered_count += 1

                except json.JSONDecodeError:
                    continue

            if filtered_count:
                print(f"Total {filtered_count} songs filtered due to timestamp validation failure")

            # Print parsing results
            print(f"\nParsing complete, results length: {len(results)}")
            print(f"Results content: {results}")
            print(start_duration, end_duration,example1_timestamp,example2_timestamp,require_length)

            # If the parsed result length is not 2, dump the raw model
            # response to test.txt for debugging.
            if len(results) != 2:
                print(f"Warning: Parsed result length is not 2, actual is {len(results)}, will write to test.txt")
                with open('test.txt', 'w', encoding='utf-8') as f:
                    if last_content is not None:
                        f.write(last_content)
                print("Written to test.txt file")

            # NOTE(review): this threshold looks like a leftover from an
            # earlier 50-songs-per-call configuration; with n=1 and a 2-song
            # prompt len(results) is normally at most 2, so this branch never
            # triggers and saving happens in the final else branch below.
            # TODO confirm whether the intended threshold is 2.
            if len(results) >= 50:
                # Append results to the file (lock ensures thread safety).
                with file_lock:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        for result in results[:50]:  # Cap at 50 songs
                            f.write(json.dumps(result, ensure_ascii=False) + '\n')

                return selected_indices, min(len(results), 50)
            elif attempt < max_retries:
                print(f"Only successfully parsed {len(results)}/50 songs, retrying...")
                time.sleep(2)
            else:
                # Last attempt: save whatever was successfully parsed.
                if len(results) > 0:
                    with file_lock:
                        with open(output_file, 'a', encoding='utf-8') as f:
                            for result in results:
                                f.write(json.dumps(result, ensure_ascii=False) + '\n')
                return selected_indices, len(results)

        except Exception as e:
            if attempt < max_retries:
                print(f"Error occurred during generation: {e}, retrying...")
                time.sleep(2)
            else:
                print(f"Generation failed: {e}")
                return selected_indices, 0

    return selected_indices, 0
384
+
385
+
386
class IndexPool:
    """Thread-safe pool of data indices with automatic reset support.

    Hands out batches of not-yet-used indices, persisting every handed-out
    index to ``selected_file`` so progress survives restarts.  When the pool
    runs dry it refills itself from the full index range.
    """

    def __init__(self, total_size, selected_file):
        self.total_size = total_size
        self.selected_file = selected_file
        self.lock = threading.Lock()
        self.available_indices = []
        self.selected_indices = set()
        self.reset_count = 0  # How many times the pool has been refilled

        # Restore any previously handed-out indices, then build the pool.
        self._load_selected_indices()
        self._reset_pool()

    def _load_selected_indices(self):
        """Restore previously handed-out indices from the persistence file."""
        if not os.path.exists(self.selected_file):
            return
        with open(self.selected_file, 'r', encoding='utf-8') as f:
            self.selected_indices.update(int(line.strip()) for line in f)

    def _reset_pool(self):
        """Rebuild the shuffled list of available (unused) indices.

        If every index has already been used, clear the used set and start
        over from the full range.
        """
        remaining = [i for i in range(self.total_size) if i not in self.selected_indices]
        random.shuffle(remaining)  # Hand indices out in random order
        self.available_indices = remaining

        if not self.available_indices:
            self.reset_count += 1
            print(f"\nIndex pool exhausted, resetting pool for the {self.reset_count}th time, re-selecting from {self.total_size} songs")
            self.selected_indices.clear()
            self.available_indices = list(range(self.total_size))
            random.shuffle(self.available_indices)

    def get_indices(self, count):
        """Thread-safely take up to ``count`` indices from the pool.

        Args:
            count: Number of indices requested.

        Returns:
            List of selected indices (may be shorter than ``count`` when the
            refreshed pool is smaller than the request).
        """
        with self.lock:
            # Refill the pool first when it cannot satisfy the request.
            if len(self.available_indices) < count:
                self._reset_pool()

            selected, self.available_indices = (
                self.available_indices[:count],
                self.available_indices[count:],
            )

            self.selected_indices.update(selected)

            # Persist immediately so a crash cannot hand out duplicates later.
            with open(self.selected_file, 'a', encoding='utf-8') as f:
                f.writelines(f"{idx}\n" for idx in selected)

            return selected

    def get_stats(self):
        """Return a snapshot of pool statistics (thread-safe)."""
        with self.lock:
            return {
                'available': len(self.available_indices),
                'selected': len(self.selected_indices),
                'reset_count': self.reset_count,
            }
461
+
462
+
463
def batch_generate_music(input_file, output_file, selected_file, total_songs=1000, sample_size=20, model='gpt-4o-mini', num_threads=10):
    """Batch-generate music descriptions and lyrics (multi-threaded).

    Loads the reference songs, then fans out ``generate_music_descriptions``
    calls over a thread pool; each call samples ``sample_size`` reference
    songs and appends the generated songs to ``output_file``.

    Args:
        input_file: Path to input jsonl file with reference songs.
        output_file: Path to output jsonl file (appended to).
        selected_file: Path to the file recording already-used sample indices.
        total_songs: Target total number of songs to generate.
        sample_size: Number of reference songs sampled per call.
        model: Model name to use.
        num_threads: Number of worker threads.
    """
    # Load all music data
    print("Loading music data...")
    all_music_data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line.strip())
            all_music_data.append(data)
    print(f"Loaded {len(all_music_data)} songs")

    # Create the thread-safe index pool (restores progress from selected_file).
    index_pool = IndexPool(len(all_music_data), selected_file)
    stats = index_pool.get_stats()
    print(f"Currently selected indices: {stats['selected']}")
    print(f"Currently available indices: {stats['available']}")

    # Number of calls needed: each call is expected to yield 2 songs, so
    # divide by 2 rounding up.  NOTE(review): the progress strings below
    # still say "5 per call" / "/5" — leftovers from an earlier 5-songs-per-
    # call configuration; only the messages are off, the math assumes 2.
    num_iterations = (total_songs + 1) // 2  # Round up for 2 songs per call
    print(f"Need to call {num_iterations} times to generate approximately {total_songs} songs (5 per call)")
    print(f"Using {num_threads} threads for parallel processing\n")

    # Lock serialising appends to the shared output file.
    file_lock = threading.Lock()

    # Statistics (updated only in the main thread, guarded anyway).
    total_generated = 0
    generated_lock = threading.Lock()

    def worker_task(task_id):
        """Run one generation call; return how many songs it produced."""
        try:
            used_indices, success_count = generate_music_descriptions(
                all_music_data=all_music_data,
                index_pool=index_pool,
                output_file=output_file,
                file_lock=file_lock,
                sample_size=sample_size,
                model=model,
                max_retries=0  # No extra retries per call
            )
            return success_count
        except Exception as e:
            print(f"Task {task_id} failed: {e}")
            return 0

    # Use a thread pool and a progress bar.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = {executor.submit(worker_task, i): i for i in range(num_iterations)}

        # Use tqdm to show progress as tasks complete (in any order).
        with tqdm(total=num_iterations, desc="Generation progress", unit="batch") as pbar:
            for future in as_completed(futures):
                success_count = future.result()

                with generated_lock:
                    total_generated += success_count

                # Get current statistics
                stats = index_pool.get_stats()

                # Update progress bar
                pbar.set_postfix({
                    'Batch': f'{success_count}/5',
                    'Total': total_generated,
                    'Remaining': stats['available'],
                    'Resets': stats['reset_count']
                })
                pbar.update(1)

    # Final statistics
    stats = index_pool.get_stats()
    print(f"\nGeneration complete!")
    print(f"Total generated: {total_generated} songs")
    print(f"Used {stats['selected']} indices")
    print(f"Remaining available indices: {stats['available']}")
    print(f"Pool reset count: {stats['reset_count']}")
553
+
554
if __name__ == '__main__':
    input_file = 'tagged_musics.jsonl'
    output_file = 'generate_lrc_5mini.jsonl'
    selected_file = 'selected.txt'
    # Each call uses n=1 with max_retries=0, samples `sample_size` reference
    # songs, and asks the model for 2 new songs.
    batch_generate_music(
        input_file=input_file,
        output_file=output_file,
        selected_file=selected_file,
        total_songs=10,
        sample_size=4,
        model='gpt-5-mini',
        num_threads=5  # Worker threads for parallel API calls
    )
568
+ # Append to txt file
data_pipeline/lyrics_gene/gen_lyrics_en.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import random
5
+ import re
6
+ from openai import OpenAI
7
+ from tqdm import tqdm
8
+ import threading
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+
11
# Set environment variables
# Note: Set these environment variables before running the script
# export OPENAI_API_KEY="your-api-key"
# export OPENAI_BASE_URL="https://api.openai.com/v1" # or your custom API URL
# NOTE(review): falling back to an empty API key lets the module import but
# every API call will fail with an auth error — confirm the key is exported.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = "" # Replace with your API key or set via environment variable
if not os.environ.get("OPENAI_BASE_URL"):
    os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1" # Replace with API URL or set via environment variable

# Initialize client (module-level singleton shared by all worker threads;
# reads OPENAI_API_KEY / OPENAI_BASE_URL from the environment set above)
client = OpenAI()
22
+
23
+
24
+ def _extract_lyrics_timestamps(lyrics_text):
25
+ """
26
+ Extract timestamps from lyrics and convert to seconds
27
+ Args:
28
+ lyrics_text: Lyrics string
29
+ Returns:
30
+ List[float]: Timestamps in order (seconds)
31
+ """
32
+ if not isinstance(lyrics_text, str):
33
+ return []
34
+ pattern = re.compile(r'\[(\d{2}):(\d{2})(?:\.(\d{2}))?\]')
35
+ timestamps = []
36
+ for match in pattern.finditer(lyrics_text):
37
+ minutes = int(match.group(1))
38
+ seconds = int(match.group(2))
39
+ fraction = match.group(3)
40
+ total_seconds = minutes * 60 + seconds
41
+ if fraction is not None:
42
+ divisor = 100 if len(fraction) == 2 else 10 ** len(fraction)
43
+ total_seconds += int(fraction) / divisor
44
+ timestamps.append(total_seconds)
45
+ return timestamps
46
+
47
+
48
def _validate_timestamps(lyrics_text, min_last_timestamp=170, max_interval=35):
    """
    Check that LRC timestamps describe a long-enough song with no oversized
    gap at the end.

    Args:
        lyrics_text: Lyrics string in LRC format.
        min_last_timestamp: Minimum value of the final timestamp (seconds).
        max_interval: Maximum gap between the last two timestamps (seconds).

    Returns:
        bool: True when all checks pass; prints the failing reason otherwise.
    """
    stamps = _extract_lyrics_timestamps(lyrics_text)
    if len(stamps) < 2:
        print("Validation failed: Timestamp count less than 2")
        return False
    second_last, last = stamps[-2], stamps[-1]
    if last < min_last_timestamp:
        print(f"Validation failed: Last timestamp {last:.2f}s is less than {min_last_timestamp}s")
        return False
    if last - second_last > max_interval:
        print(f"Validation failed: Interval between last two timestamps {last - second_last:.2f}s is greater than {max_interval}s")
        return False
    return True
71
+
72
+
73
def chat_gpt(text, model='gpt-4o-mini'):
    """
    Call the chat completions API, retrying forever until a non-empty
    response is returned.

    Args:
        text: User prompt text.
        model: Model name to use.

    Returns:
        str: Stripped response content.
    """
    while True:
        try:
            # Call OpenAI chat completions API
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": text}
                ]
            )
            # Get response content
            if getattr(completion.choices[0].message, 'content', None):
                content = completion.choices[0].message.content.strip()
                return content
            else:
                # BUG FIX: the message promises a 2s wait, but the original
                # looped back immediately, busy-hammering the API on empty
                # responses. Sleep before retrying, matching the except path.
                print('error_wait_2s')
                time.sleep(2)
        except Exception as e:
            print(f"Error: {e}")
            time.sleep(2)
93
+
94
+
95
def chat_gpt_call(text, model='gpt-4o-mini'):
    """
    Single (non-retrying) chat completion call.

    Returns the stripped response text, or None when the model produced no
    content (an 'error_wait_2s' notice is printed in that case).
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text},
        ],
    )
    message = completion.choices[0].message
    if getattr(message, 'content', None):
        return message.content.strip()
    print('error_wait_2s')
110
+
111
+
112
def generate_music_descriptions(all_music_data, index_pool, output_file, file_lock, sample_size=20, model='gpt-4o-mini', max_retries=0):
    """
    Read music data file, randomly sample, call GPT to generate new music
    descriptions and lyrics, appending validated songs to ``output_file``.

    Args:
        all_music_data: List of all music data dicts.
        index_pool: Index pool object (thread-safe).
        output_file: Path to output jsonl file (appended under ``file_lock``).
        file_lock: File write lock shared by all worker threads.
        sample_size: Number of samples to randomly extract.
        model: Model name to use.
        max_retries: Maximum retry count on short/failed parses.

    Returns:
        (used_indices, success_count): List of used indices and number of
        successfully generated songs.
    """
    # Each prompt asks the model for exactly this many songs.
    expected_songs = 2

    # ("M.SS" start, "M.SS" end, required lyric line count) buckets; a random
    # bucket fixes the target song length for this batch.
    duration_ranges = [
        ("3.00", "3.15", 60), ("3.15", "3.30", 70), ("3.30", "3.45", 80), ("3.45", "4.00", 90),
        ("4.00", "4.15", 100), ("4.15", "4.30", 100), ("4.30", "4.45", 100)
    ]
    selected_range = random.choice(duration_ranges)
    require_length = selected_range[2]

    # Directly convert to timestamp format (strictly corresponding to left and right ends of tuple)
    start_timestamp = f"[{selected_range[0].replace('.', ':')}.00]"
    end_timestamp = f"[{selected_range[1].replace('.', ':')}.00]"

    # Human-readable duration bounds for the prompt.
    # BUG FIX: the original computed these via float("3.15") // 60 / % 60,
    # which yielded "0min 3sec" instead of "3min 15sec" and fed the model a
    # duration requirement that contradicted the timestamp requirement.
    # "M.SS" encodes minutes.seconds (see the timestamp conversion above),
    # so parse it by splitting on the dot.
    start_minutes, start_secs = (int(p) for p in selected_range[0].split('.'))
    end_minutes, end_secs = (int(p) for p in selected_range[1].split('.'))
    start_duration = f"{start_minutes}min {start_secs}sec"
    end_duration = f"{end_minutes}min {end_secs}sec"

    # Random example timestamps inside the target window (only surfaced in a
    # diagnostic print after parsing).
    start_total_seconds = start_minutes * 60 + start_secs
    end_total_seconds = end_minutes * 60 + end_secs
    example1_seconds = random.randint(start_total_seconds, end_total_seconds)
    example2_seconds = random.randint(start_total_seconds, end_total_seconds)
    example1_timestamp = f"[{example1_seconds // 60:02d}:{example1_seconds % 60:02d}.00]"
    example2_timestamp = f"[{example2_seconds // 60:02d}:{example2_seconds % 60:02d}.00]"

    # Get sample indices from index pool (thread-safe)
    selected_indices = index_pool.get_indices(sample_size)
    if not selected_indices:
        return [], 0

    sample_data = [all_music_data[i] for i in selected_indices]

    # NOTE(review): styles_text / examples_text are built but no longer
    # interpolated into the prompt (the few-shot examples appear to have been
    # removed from it); kept here so they can be re-enabled easily.
    styles = []
    for data in sample_data:
        style = data.get('style', '')
        if style and style not in styles:
            styles.append(style)
    styles_text = "、".join(styles)

    examples = []
    for i, data in enumerate(sample_data, 1):
        lyrics_text = " ".join(data.get('lyrics', [])) if isinstance(data.get('lyrics'), list) else data.get('lyrics', '')
        description = data.get('description', '')
        examples.append(f"Example {i}:\ndescription: {description}\nlyrics: {lyrics_text}")
    examples_text = "\n\n".join(examples)

    prompt = f"""Generate 2 complete songs. Each song must meet the following hard requirements:
- Strictly forbidden to generate lyrics with fewer than {require_length} lines!
- The number of lyric lines for each song must be strictly greater than {require_length}. This is a hard requirement!
- The timestamp of the final line must be between {start_timestamp} and {end_timestamp}.
- The two songs must differ in duration and line count; their final timestamps must not be identical.
- The timestamp interval between adjacent lyric lines must not exceed 10 seconds! Timestamps must be continuous and progress naturally.
- Awkward gaps like "[03:25.00]in the heart[04:25.00]the last lyric" are strictly forbidden. Do not exceed a 10-second interval.
- It is strictly forbidden to repeat the entire structure or its sections after one iteration is complete. It is also strictly forbidden to repeat the same lyric line multiple times.
If any of the above requirements are not met, the generation is considered a failure. Please regenerate.
Please generate 2 new, diverse music descriptions and LRC format lyrics. The language should be English.


Creative Requirements:
1. Style and Genre must be diverse.
2. Description Tagging Requirements (Must be strictly followed):
The description field must use a structured tag format, including the following tags, separated by commas:
- Music Style tag
- Music Genre tag
- Instruments tag
- Emotional Tone tag
- Mood/Atmosphere tag
- Vocal Style and Voice tag, limited to either "male voice" or "female voice", solo performance only.
Note: Each tag should be concise. Multiple tags of the same category can be separated by a slash (e.g., "Piano/Violin").
3. Lyric Creativity: The lyrics should have depth and artistry:
- Themes can cover various aspects such as love, life, society, nature, philosophy, dreams, memories, etc.
- Use rich literary devices: metaphors, imagery, contrast, parallelism, etc.
- Express sincere emotions with a focus on rhyme and rhythm.
- The style can be narrative, lyrical, or stream-of-consciousness.
4. Lyric Structure and Length Requirements (Must be strictly followed):
- The lyrics must be organized using the following structure, with section tags annotating each part.
- The structure must strictly follow this order, for a total of 8 section tags: [Verse 1] → [Pre-Chorus] → [Chorus] → [Verse 2] → [Pre-Chorus] → [Chorus] → [Bridge] → [Chorus (Outro)].
- A single song can only have these 8 section tags. [Verse 1] and [Verse 2] appear once; [Pre-Chorus] and [Chorus] appear twice; [Bridge] and [Chorus (Outro)] appear once. Do not add or repeat extra section tags.
- Each section tag (e.g., [Verse 1], [Chorus]) must be on its own line, immediately followed by the LRC format lyrics for that section.
- Separate sections with a blank line.
- **Total Line Count Requirement**: The entire song must contain at least {require_length} lines of timestamped lyrics (not including section tags or blank lines).
5. LRC Format Mandatory Rules (Must be strictly followed):
- Each line of lyrics must be in the format `[mm:ss.xx]Lyric content`, with no space between the timestamp and the lyrics. The lyric content should be coherent.
- **Each line must contain only one short phrase of lyrics.** Start a new line when encountering punctuation like commas or periods.
- **Strictly forbidden to merge multiple sentences or clauses onto the same line.**
- Timestamps must be distributed naturally. **The first line's timestamp must not be [00:00.00]**. Allow for an instrumental intro (suggestion: start between [00:05.00] and [00:15.00]).
- Timestamp intervals must be varied: The intervals within each song must be diverse, often using decimal values. Do not use a fixed interval:
* A single song must contain a variety of different intervals; do not use the same interval for all lines (e.g., not all 4-second gaps).
* Dynamically adjust intervals based on the emotional intensity and rhythm of the lyrics.
* The gap between adjacent lines should vary to reflect the musical rhythm.
- Timestamp allocation should be reasonably inferred based on the song's style, emotion, and rhythm, not mechanically assigned based on lyric length.
- The length of each lyric line should vary naturally; do not make them all uniform.
- **The total song duration must be between {start_duration} and {end_duration} (meaning the final line's timestamp must be between {start_timestamp} and {end_timestamp}). This is a hard requirement!**
6. Lyric Length Requirement: The number of lyric lines in the lyrics field must be greater than {require_length}. If the generated length is too short, please regenerate.
7. Uniqueness and Originality: Each piece should be unique. Avoid simply repeating the content from examples.
8. Format Requirements:
- Directly return a JSON array containing 2 song objects. Each object must have only "description" and "lyrics" fields.
- `description` field: Must be in tag format, not narrative text.
- `lyrics` field: A string in LRC format with section tags.
- Strictly forbidden to insert any extra symbols, markers, comments, or explanatory text within the JSON.

LRC Format Example (with section tags):
[Verse 1]
[00:08.00]First line of lyrics
[00:12.50]Second line of lyrics
[00:17.20]Third line of lyrics

[Pre-Chorus]
[00:22.00]Pre-chorus lyrics
[00:26.50]Pre-chorus lyrics

[Chorus]
[00:31.00]Chorus lyrics
[00:35.50]Chorus lyrics

Negative Examples (to avoid):
- Incorrect: [01:30.00](Piano Interlude) - Do not add parenthetical comments after the timestamp.
- Incorrect: [00:00.00]Starting lyric - The first line cannot start at 00:00.00.
- Incorrect: [00:05.00]In the familiar field, the sun casts golden rays upon the wheat - Strictly forbidden to place multiple clauses on the same line.
- Incorrect: [03:00.00] In the light of hope[03:05.50] In the light of hope[03:10.20] In the light of hope -Excessive repetition of the exact same lyric line is strictly forbidden. Lyrical content must show variation.
Now, please fully unleash your creativity and generate 2 new, complete works of music descriptions and LRC format lyrics.
Special Reminder: Each song must be complete, not abbreviated or omitted! It must contain the full 8 sections (Verse 1, Pre-Chorus, Chorus, Verse 2, Pre-Chorus, Chorus, Bridge, Chorus Outro) and strictly ensure more than {require_length} lines of lyrics.

Directly return in JSON array format:
[
{{"description": "...", "lyrics": "..."}},
{{"description": "...", "lyrics": "..."}}
]"""
    # Try to generate with retry mechanism
    for attempt in range(max_retries + 1):
        try:
            # Call OpenAI API
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": f"You are a creative music lyricist and composer. Please generate diverse and creative music tag-based descriptions and LRC format lyrics with song structure tags. CRITICAL REQUIREMENTS: 1) Description must be structured tags separated by commas, NOT narrative text. 2) Return ONLY pure, valid JSON format without any extra symbols, markers, or comments. 3) Each song must include structure tags like [Verse 1], [Chorus], [Bridge], etc., followed by LRC format lyrics [mm:ss.xx]lyric_content. 4) MANDATORY: Each song must have MORE than {require_length} lines of lyrics with timestamps. "},
                    {"role": "user", "content": prompt}
                ],
                n=1,
                temperature=1.0,
            )
            # Extract all responses
            results = []
            filtered_count = 0
            last_content = None

            for i, choice in enumerate(completion.choices, 1):
                try:
                    content = choice.message.content.strip()
                    last_content = content
                    print(f"\n=== GPT Response {i} ===")
                    print(content)
                    print("=" * 50)
                    # Try to extract JSON content from fenced code blocks
                    if "```json" in content:
                        content = content.split("```json")[1].split("```")[0].strip()
                    elif "```" in content:
                        content = content.split("```")[1].split("```")[0].strip()

                    # Clean trailing commas before a closing brace/bracket
                    content = re.sub(r',(\s*[}\]])', r'\1', content)

                    # Parse JSON array
                    result_array = json.loads(content)

                    if isinstance(result_array, list):
                        # Validate each object in array
                        for song in result_array:
                            if isinstance(song, dict) and 'description' in song and 'lyrics' in song:
                                if _validate_timestamps(song.get('lyrics', '')):
                                    results.append(song)
                                else:
                                    filtered_count += 1
                    # If returned a single object (compatibility with old format)
                    elif isinstance(result_array, dict) and 'description' in result_array and 'lyrics' in result_array:
                        if _validate_timestamps(result_array.get('lyrics', '')):
                            results.append(result_array)
                        else:
                            filtered_count += 1

                except json.JSONDecodeError:
                    continue

            if filtered_count:
                print(f"Total {filtered_count} songs filtered due to timestamp validation failure")

            # Print parsing results
            print(f"\nParsing complete, results length: {len(results)}")
            print(f"Results content: {results}")
            print(start_duration, end_duration, example1_timestamp, example2_timestamp, require_length)

            # Dump the raw response for inspection when the parse count is off
            if len(results) != expected_songs:
                print(f"Warning: Parsed result length is not {expected_songs}, actual is {len(results)}, will write to test.txt")
                with open('test.txt', 'w', encoding='utf-8') as f:
                    if last_content is not None:
                        f.write(last_content)
                print("Written to test.txt file")

            # BUG FIX: the success threshold was a stale `>= 50` left over from
            # a 50-songs-per-call version; with the current 2-song prompt it was
            # unreachable and results were only saved by the fallback branch.
            if len(results) >= expected_songs:
                # Append results under the lock to keep the jsonl consistent
                with file_lock:
                    with open(output_file, 'a', encoding='utf-8') as f:
                        for result in results[:expected_songs]:
                            f.write(json.dumps(result, ensure_ascii=False) + '\n')

                return selected_indices, min(len(results), expected_songs)
            elif attempt < max_retries:
                print(f"Only successfully parsed {len(results)}/{expected_songs} songs, retrying...")
                time.sleep(2)
            else:
                # Last attempt: save whatever was parsed, even if short
                if len(results) > 0:
                    with file_lock:
                        with open(output_file, 'a', encoding='utf-8') as f:
                            for result in results:
                                f.write(json.dumps(result, ensure_ascii=False) + '\n')
                return selected_indices, len(results)

        except Exception as e:
            if attempt < max_retries:
                print(f"Error occurred during generation: {e}, retrying...")
                time.sleep(2)
            else:
                print(f"Generation failed: {e}")
                return selected_indices, 0

    return selected_indices, 0
393
+
394
+
395
class IndexPool:
    """Thread-safe index pool with automatic reset support.

    Tracks which rows of the source dataset have already been sampled,
    persisting used indices to ``selected_file`` so state survives restarts.
    When the pool runs dry it resets and starts another pass over the
    whole dataset.
    """

    def __init__(self, total_size, selected_file):
        self.total_size = total_size
        self.selected_file = selected_file
        self.lock = threading.Lock()
        self.available_indices = []
        self.selected_indices = set()
        self.reset_count = 0  # Record reset count

        # Load selected indices from file
        self._load_selected_indices()
        # Initialize available indices
        self._reset_pool()

    def _load_selected_indices(self):
        """Load previously selected indices from file (one integer per line)."""
        if os.path.exists(self.selected_file):
            with open(self.selected_file, 'r', encoding='utf-8') as f:
                for line in f:
                    # ROBUSTNESS FIX: tolerate blank/garbage lines (e.g. a
                    # trailing newline) instead of crashing on int().
                    try:
                        self.selected_indices.add(int(line.strip()))
                    except ValueError:
                        continue

    def _reset_pool(self):
        """Recompute the available indices; start a fresh pass when exhausted."""
        # Calculate available indices
        self.available_indices = [i for i in range(self.total_size) if i not in self.selected_indices]
        random.shuffle(self.available_indices)  # Shuffle order

        if len(self.available_indices) == 0:
            # All indices used: clear the selection and start a new pass
            self.reset_count += 1
            print(f"\nIndex pool exhausted, resetting pool for the {self.reset_count}th time, re-selecting from {self.total_size} songs")
            self.selected_indices.clear()
            self.available_indices = list(range(self.total_size))
            random.shuffle(self.available_indices)

    def get_indices(self, count):
        """
        Thread-safe get specified number of indices.

        Args:
            count: Number of indices needed.

        Returns:
            List of selected indices (of length ``count`` whenever
            ``count <= total_size``).
        """
        with self.lock:
            # Check if pool needs to be reset
            if len(self.available_indices) < count:
                self._reset_pool()
            if len(self.available_indices) < count and count <= self.total_size:
                # BUG FIX: _reset_pool only starts a new pass when the pool is
                # completely empty, so a request larger than the remaining
                # unselected set used to silently return fewer indices than
                # asked. Force a full reset so the caller always gets `count`.
                self.reset_count += 1
                self.selected_indices.clear()
                self.available_indices = list(range(self.total_size))
                random.shuffle(self.available_indices)

            # Get indices
            selected = self.available_indices[:count]
            self.available_indices = self.available_indices[count:]

            # Add to selected set
            for idx in selected:
                self.selected_indices.add(idx)

            # Persist so a restart does not resample the same rows
            with open(self.selected_file, 'a', encoding='utf-8') as f:
                for idx in selected:
                    f.write(f"{idx}\n")

            return selected

    def get_stats(self):
        """Get statistics about pool usage."""
        with self.lock:
            return {
                'available': len(self.available_indices),
                'selected': len(self.selected_indices),
                'reset_count': self.reset_count
            }
470
+
471
+
472
def batch_generate_music(input_file, output_file, selected_file, total_songs=1000, sample_size=20, model='gpt-4o-mini', num_threads=10):
    """
    Batch generate music descriptions and lyrics (multi-threaded version).

    Args:
        input_file: Path to input jsonl file with reference songs.
        output_file: Path to output jsonl file (appended to).
        selected_file: Path to file recording selected indices.
        total_songs: Total number of songs to generate.
        sample_size: Number of samples to extract each time.
        model: Model name to use.
        num_threads: Number of worker threads.
    """
    # generate_music_descriptions produces at most this many songs per call.
    # BUG FIX: progress messages previously said "5 per call" / "/5" although
    # each API call only yields 2 songs.
    songs_per_call = 2

    # Load all music data
    print("Loading music data...")
    all_music_data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            all_music_data.append(json.loads(line.strip()))
    print(f"Loaded {len(all_music_data)} songs")

    # Create thread-safe index pool
    index_pool = IndexPool(len(all_music_data), selected_file)
    stats = index_pool.get_stats()
    print(f"Currently selected indices: {stats['selected']}")
    print(f"Currently available indices: {stats['available']}")

    # Number of API calls needed, rounded up
    num_iterations = (total_songs + songs_per_call - 1) // songs_per_call
    print(f"Need to call {num_iterations} times to generate approximately {total_songs} songs ({songs_per_call} per call)")
    print(f"Using {num_threads} threads for parallel processing\n")

    # Create file write lock
    file_lock = threading.Lock()

    # Statistics
    total_generated = 0
    generated_lock = threading.Lock()

    def worker_task(task_id):
        """Run one generation batch; returns the number of songs produced."""
        try:
            used_indices, success_count = generate_music_descriptions(
                all_music_data=all_music_data,
                index_pool=index_pool,
                output_file=output_file,
                file_lock=file_lock,
                sample_size=sample_size,
                model=model,
                max_retries=0  # no in-call retries; failures just yield 0 songs
            )
            return success_count
        except Exception as e:
            print(f"Task {task_id} failed: {e}")
            return 0

    # Use thread pool and progress bar
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = {executor.submit(worker_task, i): i for i in range(num_iterations)}

        # Use tqdm to show progress
        with tqdm(total=num_iterations, desc="Generation progress", unit="batch") as pbar:
            for future in as_completed(futures):
                success_count = future.result()

                with generated_lock:
                    total_generated += success_count

                # Get current statistics
                stats = index_pool.get_stats()

                # Update progress bar
                pbar.set_postfix({
                    'Batch': f'{success_count}/{songs_per_call}',
                    'Total': total_generated,
                    'Remaining': stats['available'],
                    'Resets': stats['reset_count']
                })
                pbar.update(1)

    # Final statistics
    stats = index_pool.get_stats()
    print(f"\nGeneration complete!")
    print(f"Total generated: {total_generated} songs")
    print(f"Used {stats['selected']} indices")
    print(f"Remaining available indices: {stats['available']}")
    print(f"Pool reset count: {stats['reset_count']}")
561
+
562
+
563
if __name__ == '__main__':
    # Entry point: batch-generate ~100 English LRC songs with gpt-4o-mini,
    # sampling 4 reference songs per call and running 20 worker threads.
    batch_generate_music(
        input_file='tagged_musics.jsonl',
        output_file='generate_en_lrc.jsonl',
        selected_file='selected.txt',
        total_songs=100,
        sample_size=4,
        model='gpt-4o-mini',
        num_threads=20,
    )
data_pipeline/meta_process/convert_convs.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate multi-turn dialogue data, each turn contains lyric text and corresponding audio token slices
3
+ """
4
+
5
+ import json
6
+ import re
7
+ import os
8
+ import torch
9
+ from tqdm import tqdm
10
+ from my_tool import load_jsonl
11
+
12
# assumes MuCodec emits 25 tokens per second of audio — TODO confirm against encoder config
TOKEN_PER_SECOND = 25 # Number of tokens per second of audio
NUM_ITEMS = 100000 # Process N items first
# Matches LRC-style timestamps: [m:ss], [mm:ss.xx] or [mm:ss:xx] (1-3 fractional digits)
timestamp_pattern = re.compile(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]")
+
17
+ def _parse_lyric_with_timestamps(lyric: str):
18
+ """
19
+ Return [(start_time_s, text), ...] sorted by timestamp
20
+ """
21
+ result = []
22
+ for match in timestamp_pattern.finditer(lyric):
23
+ start_idx = match.end()
24
+ end_idx = lyric.find("[", start_idx)
25
+ text = lyric[start_idx:end_idx].strip() if end_idx != -1 else lyric[start_idx:].strip()
26
+ if not text:
27
+ continue
28
+ minute = int(match.group(1))
29
+ second = int(match.group(2))
30
+ ms = int(match.group(3)) if match.group(3) else 0
31
+ total_seconds = minute * 60 + second + ms / 1000
32
+ result.append((total_seconds, text))
33
+ return result
34
+
35
+ def _load_audio_tokens(pt_file):
36
+ """
37
+ Load MuCodec encoding of audio
38
+ """
39
+ audio_ids = torch.load(pt_file, map_location="cpu").squeeze().long()
40
+ return audio_ids
41
+
42
def _get_token_slice(audio_tokens, start_s, end_s):
    """Render the tokens covering [start_s, end_s) as an [SOA]...[EOA] string."""
    lo = int(start_s * TOKEN_PER_SECOND)
    hi = int(end_s * TOKEN_PER_SECOND)
    body = "".join(f"<AUDIO_{tok.item()}>" for tok in audio_tokens[lo:hi])
    return "[SOA]" + body + "[EOA]"
48
+
49
def _process_item(item, pt_dir:str):
    """Build one multi-turn conversation for a song: an intro request with the
    song metadata, then one user/assistant round per lyric line where the
    assistant reply is the audio-token slice covering that line.

    Args:
        item: Song metadata dict (expects 'song'/'name', optional 'tlyric',
              'lyric', 'style', 'emotion', 'rhythm', 'description', 'singer',
              'lang').
        pt_dir: Directory containing per-song MuCodec token .pt files.

    Returns:
        List of {"role", "content"} messages, or None when the .pt file is
        missing or no timestamped lyrics were parsed.
    """
    song_name = item.get("song") or item.get("name")
    song_name = song_name.split('.mp3')[0] # For mucodec, remove extension
    pt_file = os.path.join(pt_dir, f"{song_name}.pt")
    if not os.path.exists(pt_file):
        return None

    audio_tokens = _load_audio_tokens(pt_file)
    # Prefer the longer of translated vs. original lyric text
    # (presumably the longer one carries timestamps for more lines — TODO confirm)
    tlyric_ = item.get('tlyric', "")
    lyric_ = item.get('lyric', "")
    lyric = tlyric_ if len(tlyric_) > len(lyric_) else lyric_
    lyrics_ts = _parse_lyric_with_timestamps(lyric)

    if not lyrics_ts:
        # Skip if no lyrics
        return None

    rounds = []

    # First generate a system message containing song information
    intro_text = (
        f"请生成一首歌曲,歌名为《{item.get('name', '')}》,风格是{item.get('style','')}"
        f",情绪为{item.get('emotion','')},节奏:{item.get('rhythm','')},"
        f"{item.get('description','')},由{item.get('singer','')}演唱,语言:{item.get('lang','')}。"
        f"歌词如下:" + " ".join([text for _, text in lyrics_ts]) + "接下来我会逐句告诉你需要生成歌曲片段的歌词,\n请先生成前奏"
    )
    rounds.append({"role": "user", "content": intro_text})
    rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, 0, lyrics_ts[0][0])}) # Intro tokens

    # Each lyric line corresponds to one round
    for idx, (start_s, text) in enumerate(lyrics_ts[:-1]): ## Last line handled separately
        # NOTE(review): the else branch is dead — `idx + 1 < len(lyrics_ts)` is
        # always true when iterating lyrics_ts[:-1]; each slice simply ends at
        # the next line's start.
        end_s = lyrics_ts[idx + 1][0] if idx + 1 < len(lyrics_ts) else len(audio_tokens)/TOKEN_PER_SECOND # Last line to end of audio
        rounds.append({"role": "user", "content": text})
        rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, start_s, end_s)})

    # Tail processing logic: the final line's slice runs to the end of the audio
    rounds.append({"role": "user", "content": f"请生成歌词{lyrics_ts[-1][1]}以及歌曲结尾"})
    rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, lyrics_ts[-1][0], len(audio_tokens)/TOKEN_PER_SECOND)})

    return rounds
89
+
90
+ # ===== External Interface =====
91
+
92
def get_convert_convs(dataset:list[dict], pt_dir:str, save_path:str):
    """Convert each dataset item into a conversation and dump them as JSONL."""
    with open(save_path, "w", encoding="utf-8") as sink:
        for record in tqdm(dataset, desc="Converting convs"):
            conversation = _process_item(record, pt_dir)
            if conversation:
                line = json.dumps({"messages": conversation}, ensure_ascii=False)
                sink.write(line + "\n")
data_pipeline/meta_process/convert_lyrics.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import copy
5
+ from tqdm import tqdm
6
+ from my_tool import dict_sort_print
7
+ from collections import defaultdict
8
+ from convert_convs import _parse_lyric_with_timestamps
9
+
10
+ # ===== Lyric Parsing =====
11
+
12
def _parse_lyrics(text:str) -> dict:
    """Parse metadata, lyrics and timestamps from lyric information.

    Returns a dict with:
    - ``lyrics_meta``: key/value pairs from lines containing a half- or
      full-width colon
    - ``lyrics``: lyric lines, with ``"<nop>"`` markers for silent gaps
    - ``lyrics_time``: one timestamp per entry in ``lyrics``
    """
    segs = text.split("\n")
    metadata = {
        "lyrics_meta": {},
        "lyrics": [],
        "lyrics_time": [],
    }
    for seg in segs:
        # Format: [time] metadata / lyrics
        results = _parse_lyric_with_timestamps(seg)
        for time, content in results:
            if ":" in content or ":" in content:
                # Metadata line: split on the first half-width colon,
                # falling back to the full-width one
                pos1 = content.find(":")
                pos2 = content.find(":")
                pos = pos1 if pos1 != -1 else pos2
                key = content[:pos].strip()
                value = content[pos+1:].strip()
                metadata["lyrics_meta"][key] = value
            elif time == "00:00.00":
                # Unstructured metadata at the beginning
                # NOTE(review): this compares ``time`` to a string; confirm the
                # timestamp type returned by ``_parse_lyric_with_timestamps``.
                continue
            elif len(metadata['lyrics']) == 0 and "/" in content:
                # Unstructured metadata at the beginning
                continue
            else:
                if len(content) == 0:
                    # Middle gap/end: record at most one consecutive "<nop>"
                    if len(metadata['lyrics']) != 0 and metadata['lyrics'][-1] != "<nop>":
                        # If there's no previous segment (beginning), or previous segment is empty, don't record (merge)
                        metadata['lyrics'].append("<nop>")
                        metadata['lyrics_time'].append(time)
                else:
                    if len(metadata['lyrics_time']) != 0 and metadata['lyrics_time'][-1] == time and time != "<nop>":
                        # Same timestamp means it's a translation (don't record)
                        continue
                    # Actual lyrics
                    metadata['lyrics'].append(content)
                    metadata['lyrics_time'].append(time)
    return metadata
54
+
55
+ # ===== Language Detection =====
56
+
57
+ def _count_ch_nan(text:str):
58
+ """Count the number of Chinese and other non-English characters in a string"""
59
+ ch_num = 0
60
+ nan_num = 0
61
+ nan = ""
62
+ for c in text:
63
+ if '\u4e00' <= c <= '\u9fff':
64
+ ch_num += 1
65
+ elif ('a' <= c <= 'z') or ('A' <= c <= 'Z') or len(c.strip()) == 0:
66
+ continue
67
+ else:
68
+ nan_num += 1
69
+ nan += c
70
+ # if len(nan) > 0:
71
+ # print(nan)
72
+ return ch_num, nan_num
73
+
74
def _lang_decide(lyrics:list[str], val_limit:int=5, word_limit:int=3) -> str:
    """
    Determine the language type of lyrics (en/zh/ez/instrument/nan)
    - val_limit: Only count if there are at least this many sentences
    - word_limit: Only count if a sentence has at least this many words
    """
    ch_lyrics = 0
    en_lyrics = 0
    nan_lyrics = 0
    for lyric in lyrics:
        # "<nop>" entries mark silent gaps, not lyric content
        if lyric.strip() == "<nop>":
            continue
        # Replace punctuation/digits with spaces so only letters and CJK remain
        lyric = re.sub(r"[''¥·′´(),。?""!@#$%^&*()?.'/,=+_—— !…《》<>0-9~※~;-・\"、☆|△【】#「」‖{}\[\]-]", " ", lyric)
        ch_num, nan_num = _count_ch_nan(lyric)

        if nan_num > word_limit:
            nan_lyrics += 1
            continue
        elif ch_num > word_limit:
            ch_lyrics += 1

        lyric = re.sub(r'[\u4e00-\u9fff]+', '', lyric)
        # Count English words by whitespace separation.
        # BUGFIX: use split() instead of split(" ") so runs of spaces (left by
        # the substitutions above) do not count as words; previously a purely
        # Chinese line with several spaces was also counted as English,
        # misclassifying "zh" songs as "ez".
        en_num = len(lyric.split())
        if en_num > word_limit:
            en_lyrics += 1

    if nan_lyrics > val_limit:
        return "nan"
    if ch_lyrics > val_limit and en_lyrics > val_limit:
        return "ez"
    if ch_lyrics > val_limit:
        return "zh"
    if en_lyrics > val_limit:
        return "en"
    return "instrument"
111
+
112
+ # ===== External Interface =====
113
+
114
def get_convert_lyrics(dataset:list[dict], save_path:str, dir:str, src_subfix:str=""):
    """Convert lyrics and annotate language type (need to locate corresponding song).

    For each item that has lyrics: parse the raw lyric text into structured
    metadata/lines/timestamps, classify its language, strip redundant fields,
    and stream the result to *save_path* as JSONL.
    Returns (converted items, unmatched items).
    NOTE(review): ``dir`` is unused and ``unmatch`` is always returned empty —
    confirm whether matching moved entirely to ``get_match_music``.
    """
    new_dataset = []
    lang_count = defaultdict(int)
    unmatch = []
    with open(save_path, 'w', encoding='utf-8') as file:
        for ele in tqdm(dataset, desc="Converting Lyrics"):
            # Deep-copy so the caller's dicts are not mutated by del/updates below
            ele = copy.deepcopy(ele)
            # Skip if no lyrics
            if not ele['has_lyric']:
                # Don't add to final result
                continue
            # Get lyrics; fall back to the translated lyric when the original is empty
            lyric = ele['lyric']
            if lyric == "":
                lyric = ele['tlyric']

            # Parse lyrics
            new_data = _parse_lyrics(lyric)

            # Language detection
            lang = _lang_decide(new_data['lyrics'])
            lang_count[lang] += 1

            # Remove redundant fields
            del ele['artists']
            del ele['lyric']
            del ele['tlyric']
            del ele['has_lyric']

            # Add new fields
            ele['lyric_lang'] = lang
            ele['source'] += src_subfix
            for key, value in new_data.items():
                ele[key] = value

            new_dataset.append(ele)
            json.dump(ele, file, ensure_ascii=False)
            file.write("\n")

    # Print the language distribution for a quick sanity check
    dict_sort_print(lang_count)
    return new_dataset, unmatch
156
+
157
def get_match_music(music_data:list[dict], lyric_data:list[dict]):
    """Split songs into those whose filename matches a lyric entry and those without.

    A lyric entry matches when the song file's basename equals
    "<name-without-spaces> - <artist>.mp3".  Matched lyric entries get their
    ``path`` field set to the song's path.
    """
    # 1. Index lyric entries by the mp3 filename they are expected to match
    name_map = {}
    for meta in tqdm(lyric_data, desc="Existing Lyrics"):
        title = re.sub(" ", "", meta['name'])
        complete_name = f"{title} - {meta['artist']}.mp3"
        name_map[complete_name] = meta

    # 2. Partition songs by whether their basename was indexed
    matches = []
    unmatches = []
    for song in tqdm(music_data, desc="Check Matching"):
        song_path = song['path']
        base = os.path.basename(song_path)
        if base in name_map:
            meta = name_map[base]
            meta['path'] = song_path
            matches.append(meta)
        else:
            unmatches.append(song)
    return matches, unmatches
data_pipeline/meta_process/convert_messages.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate multi-turn dialogue data based on segment descriptions (segment-by-segment generation).
3
+
4
+ Input:
5
+ - MuseData/data/meta_suno_cn.jsonl # Base meta (contains src_path, lyrics, lang, tag, etc.)
6
+ - 3_block_80000_cn_desc.jsonl # Segment descriptions, contains startS/endS/text/desc for each section
7
+ - mucodec pt directory: same as PT_DIR_CN in multi_data_suno.py
8
+
9
+ Output:
10
+ - MuseData/sft_dataset_suno_cn.jsonl # Multi-turn dialogues generated segment by segment
11
+
12
+ Dialogue format example (refer to temp1.jsonl):
13
+ - First user: Summary prompt (Chinese), explains "segment-by-segment" + provides [Intro dsec...] description
14
+ - First assistant: Intro tokens (0 ~ first section's startS)
15
+ - Subsequent segments:
16
+ user.content = "[{Section} desc]{desc}\\n{text}"
 17
18
+ """
19
+
20
+ import json
21
+ import os
22
+ import re
23
+ from typing import Dict, List, Tuple, Optional
24
+
25
+ import torch
26
+ from tqdm import tqdm
27
+ from my_tool import load_jsonl, clean_newlines
28
+
29
+ # Language configuration
30
+ LANG = "en"
31
+ # Path configuration
32
+ META_FILE = f"meta_suno_{LANG}.jsonl"
33
+ # desc folder, read all *.jsonl files in it
34
+ DESC_DIR = "desc"
35
+ PT_DIR = f"suno_mucodec_{LANG}"
36
+ # Output directory (different from meta directory), each desc file generates a set of three files
37
+ OUTPUT_DIR = "outputs"
38
+ OUTPUT_BASENAME = "minus_phonemes"
39
+
40
+ TOKEN_PER_SECOND = 25
41
+
42
+
43
+ LOG_FILE = os.path.join(OUTPUT_DIR, "section_mismatch_cn.log") # Place in output directory
44
+
45
def _log_warning(msg: str):
    """Echo *msg* to stdout and append it to the mismatch log file."""
    print(msg)
    with open(LOG_FILE, 'a', encoding='utf-8') as log:
        log.write(f"{msg}\n")
50
+
51
+
52
+ # Timestamp parsing regex
53
+ timestamp_pattern = re.compile(
54
+ r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]"
55
+ )
56
+
57
+
58
def load_pt(pt_file: str) -> torch.Tensor:
    """Load a mucodec token tensor from disk as a squeezed int64 tensor on CPU."""
    tokens = torch.load(pt_file, map_location="cpu")
    return tokens.squeeze().long()
60
+
61
+
62
def get_token_slice(audio: torch.Tensor, start_s: float, end_s: float) -> str:
    """Render the audio tokens between two timestamps as an "[SOA]...[EOA]" string.

    Negative times are clamped to 0 and indices are clamped to the tensor
    length; an empty/inverted window yields "[SOA][EOA]".
    """
    start_s = max(start_s, 0)
    end_s = max(end_s, 0)
    n = audio.shape[0]
    lo = min(max(int(start_s * TOKEN_PER_SECOND), 0), n)
    hi = min(max(int(end_s * TOKEN_PER_SECOND), 0), n)
    window = audio[lo:hi] if hi > lo else []
    body = "".join(f"<AUDIO_{int(tok)}>" for tok in window)
    return "[SOA]" + body + "[EOA]"
76
+
77
+
78
def infer_pt_path(src_path: str) -> Optional[str]:
    """Map a source audio path to its expected mucodec ``.pt`` path (None for empty input)."""
    if not src_path:
        return None
    base = os.path.basename(src_path)
    stem, _ = os.path.splitext(base)
    return os.path.join(PT_DIR, stem + ".pt")
83
+
84
+
85
def parse_lyric_with_timestamps(lyric: str) -> List[Tuple[float, str]]:
    """
    Parse [(start_time_s, text), ...] from lyrics with timestamps, sorted by time ascending.
    The returned timestamps come from the lyrics field in the meta file.

    Each ``[mm:ss.xx]`` / ``[mm:ss:xxx]`` marker starts an entry whose text runs
    up to the next marker (or end of string), with surrounding whitespace stripped.
    """
    result: List[Tuple[float, str]] = []
    matches = list(timestamp_pattern.finditer(lyric))

    for i, match in enumerate(matches):
        # Text span: from the end of this marker to the start of the next one
        start_idx = match.end()
        if i + 1 < len(matches):
            end_idx = matches[i + 1].start()
        else:
            end_idx = len(lyric)

        text = lyric[start_idx:end_idx].strip()
        minutes = int(match.group(1))
        seconds = int(match.group(2))
        ms_str = match.group(3) if match.group(3) else "0"

        # Fractional part: 2 digits = centiseconds, 3 digits = milliseconds.
        if len(ms_str) == 2:
            fractional_seconds = int(ms_str) / 100.0
        elif len(ms_str) == 3:
            fractional_seconds = int(ms_str) / 1000.0
        else:
            # NOTE(review): a single-digit fraction is treated as milliseconds
            # here ("[0:01.5]" -> 0.005s) — confirm that is intended.
            fractional_seconds = int(ms_str) / 1000.0 if ms_str else 0.0

        total_seconds = minutes * 60 + seconds + fractional_seconds
        result.append((total_seconds, text))
    return result
115
+
116
+
117
def extract_section_from_text(text: str) -> Optional[str]:
    """
    Relaxed rule: if a bracket group starting with a letter exists, return its
    content (trimmed); otherwise None.
    """
    found = re.search(r'\[([A-Za-z][A-Za-z0-9\s\-\(\)]*)\]', text)
    return found.group(1).strip() if found else None
126
+
127
def format_section_label(sec_name: str) -> str:
    """Keep original spaces, only trim leading and trailing whitespace."""
    # Interior whitespace is meaningful for labels like "Verse 1", so no re-spacing.
    return sec_name.strip()
130
+
131
+
132
def normalize_section_name(sec_name: str) -> str:
    """
    Canonicalize a section name for matching: spaces removed, lowercased,
    and any trailing run of digits dropped ("Verse 1" -> "verse").
    """
    compact = sec_name.replace(" ", "").lower()
    return re.sub(r"\d+$", "", compact)
144
+
145
+
146
def clean_desc(desc: str) -> str:
    """
    Normalize a section description:
    1. drop a leading "[desc]" marker,
    2. unwrap the string when it is enclosed in brackets.
    Falsy input is returned unchanged.
    """
    if not desc:
        return desc

    desc = desc.strip()

    marker = "[desc]"
    if desc.startswith(marker):
        desc = desc[len(marker):].strip()

    if desc[:1] == "[" and desc.endswith("]"):
        desc = desc[1:-1].strip()

    return desc
166
+
167
+
168
def build_desc_map(desc_path_or_dir: str) -> Dict[Tuple[str, int], dict]:
    """
    Build a (song_id, track_index) -> desc-record mapping.

    Accepts either a single .jsonl file or a directory containing multiple
    .jsonl files.  For a directory, files are read in sorted name order, so
    later records with the same key overwrite earlier ones.  Lines that are
    not valid JSON are appended to error.txt and skipped.
    """
    # Annotation fixed: the whole record dict is stored, not just its sections list.
    mapping: Dict[Tuple[str, int], dict] = {}

    paths: List[str] = []
    if os.path.isdir(desc_path_or_dir):
        for name in sorted(os.listdir(desc_path_or_dir)):
            if name.endswith(".jsonl"):
                paths.append(os.path.join(desc_path_or_dir, name))
    else:
        paths.append(desc_path_or_dir)

    for path in paths:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Narrowed from a bare Exception: only malformed JSON is
                    # expected here; anything else should surface.
                    with open("error.txt", 'a', encoding='utf-8') as error_file:
                        error_file.write(line + "\n")
                    continue

                song_id = obj.get("song_id")
                track_idx = obj.get("track_index", 0)

                # Store the entire record so callers can also read
                # tags/similarity scores, not only "sections".
                mapping[(song_id, track_idx)] = obj
    return mapping
201
+
202
+
203
def extract_suffix_from_desc_path(desc_path: str, fallback_idx: int) -> str:
    """
    Derive the output-file suffix from a desc filename.

    Extracts <suffix> from the pattern "3_block_<suffix>_(cn|en)_desc.jsonl";
    falls back to str(fallback_idx) when the pattern does not match.
    """
    match = re.search(
        r"3_block_([^_]+)_(?:cn|en)_desc\.jsonl",
        os.path.basename(desc_path),
        re.IGNORECASE,
    )
    return match.group(1) if match else str(fallback_idx)
214
+
215
+
216
def extract_suffix_num(desc_path: str, fallback_idx: int) -> int:
    """Numeric form of the desc-file suffix (for sorting); *fallback_idx* when non-numeric."""
    raw = extract_suffix_from_desc_path(desc_path, fallback_idx)
    try:
        return int(raw)
    except ValueError:
        return fallback_idx
226
+
227
+
228
def build_messages(item: dict, obj: dict, audio: torch.Tensor) -> Optional[dict]:
    """
    Use timestamps from meta to split audio, desc_sections provide desc and text.

    Builds a multi-turn conversation: an intro prompt + intro tokens, then one
    user/assistant pair per section (user = "[Label][desc:...][lyrics:...]",
    assistant = the matching token slice), and optionally an Outro pair.
    Returns the sample dict (meta fields + "messages"), or None when the
    inputs are unusable.
    """
    if not obj:
        return None
    desc_sections = obj.get("sections", [])

    # Parse timestamps from meta's lyrics
    lyrics_raw = item.get("lyrics", "") or ""
    meta_ts_list = parse_lyric_with_timestamps(lyrics_raw)
    if not meta_ts_list:
        return None

    total_seconds = audio.shape[0] / float(TOKEN_PER_SECOND)

    # Sort desc_sections (by startS, for matching)
    desc_sections = sorted(desc_sections, key=lambda x: x.get("startS", 0.0))

    # Get middle sections from desc_sections (skip first Intro and last Outro)
    # Intro and Outro are not in meta's lyrics, need separate handling
    # Middle sections are matched in order
    middle_desc_sections = desc_sections[1:-1] if len(desc_sections) > 2 else desc_sections[1:] if len(desc_sections) > 1 else []

    # Identify sections from meta timestamps and build mapping
    # One section may contain multiple lyric lines, need to merge
    section_timestamps: List[Tuple[str, float, float, str, str]] = []  # (section_name, start_s, end_s, text, desc)
    current_section: Optional[Tuple[str, float, str]] = None  # (section_name, start_s, accumulated_text)
    desc_idx = 0  # For matching desc in order (from middle_desc_sections)

    for idx, (start_s, text) in enumerate(meta_ts_list):
        # Extract section name
        section_name = extract_section_from_text(text)

        if section_name:
            # Encountered new section label
            # First save previous section (if any)
            if current_section:
                # Determine end time of previous section (current timestamp)
                prev_sec_name, prev_start_s, prev_text = current_section
                # Only remove timestamp, keep all other content (including line breaks, section labels, etc.)
                clean_prev_text = re.sub(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]", "", prev_text)

                # Get desc in order (match from middle sections in order)
                prev_desc = ""
                if desc_idx < len(middle_desc_sections):
                    prev_desc = clean_desc(middle_desc_sections[desc_idx].get("desc", ""))
                    desc_idx += 1

                section_timestamps.append((prev_sec_name, prev_start_s, start_s, clean_prev_text, prev_desc))

            # Start new section
            current_section = (section_name, start_s, text)
        else:
            # No section label, belongs to subsequent lines of current section
            if current_section:
                sec_name, sec_start, sec_text = current_section
                # Preserve line breaks, connect with line breaks
                current_section = (sec_name, sec_start, sec_text + "\n" + text)
            # If no current_section, skip (might be empty line before Intro)

    # Process last section
    # Check if last timestamp is empty text (indicates end marker)
    outro_start_s: Optional[float] = None
    if meta_ts_list and not meta_ts_list[-1][1].strip():
        # Last timestamp is empty text, indicates end marker
        outro_start_s = meta_ts_list[-1][0]

    if current_section:
        sec_name, sec_start, sec_text = current_section
        # If last timestamp is empty text, last section's end time should be this timestamp
        # Otherwise use total duration
        if outro_start_s is not None:
            end_s = outro_start_s
        else:
            end_s = total_seconds

        # Only remove timestamp, keep all other content
        clean_text = re.sub(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]", "", sec_text)

        # Get desc in order (match from middle sections in order)
        desc = ""
        if desc_idx < len(middle_desc_sections):
            desc = clean_desc(middle_desc_sections[desc_idx].get("desc", ""))
            desc_idx += 1

        section_timestamps.append((sec_name, sec_start, end_s, clean_text, desc))

    # Check if counts match (only check middle sections, excluding Intro and Outro)
    if desc_idx != len(middle_desc_sections):
        _log_warning(f"⚠️ Warning: Section count mismatch! meta has {len(section_timestamps)} sections, desc has {len(middle_desc_sections)} middle sections (excluding Intro and Outro) (song_id: {item.get('song_id')}, track_index: {item.get('track_index')})")

    if not section_timestamps:
        return None

    # Intro segment: from 0 to first section's start time
    # Intro's desc should have been obtained in sequential matching, but Intro itself is not in meta's lyrics
    # So need to get from desc_sections' first section (usually Intro)
    first_section_start = section_timestamps[0][1] if section_timestamps else total_seconds
    intro_desc = ""
    if desc_sections and desc_sections[0].get("section", "").lower() == "intro":
        intro_desc = clean_desc(desc_sections[0].get("desc", ""))

    # Pick the style tag from the desc record rather than the meta item
    song_id:str = obj.get("song_id", "")
    omni_tag = obj.get("omni", "")
    style_tag = obj.get("style", "")

    if song_id.find("cn") != -1:
        # Chinese songs use omni directly
        tag = omni_tag
    else:
        # English songs compare omni / style
        style_sim = obj.get("style_sim", 0)
        omni_sim = obj.get("omni_sim", 0)
        try:
            tag = omni_tag if omni_sim > style_sim else style_tag
        except Exception as e:
            # If sim score is invalid (non-comparable), default to omni
            tag = omni_tag
            print(f"Error: {song_id}, {e}")

    intro_prompt = (
        f"Please generate a song in the following style:{tag}.\n"
        "Next, I will tell you the requirements and lyrics for the song fragment to be generated, section by section.\n"
        f"[Intro][desc:{intro_desc}]"
    )

    messages: List[dict] = []
    messages.append({"role": "user", "content": intro_prompt})
    messages.append(
        {
            "role": "assistant",
            "content": get_token_slice(audio, 0.0, first_section_start),
        }
    )

    # Process segment by segment (using timestamps from meta)
    for idx, (sec_name, start_s, end_s, text, desc) in enumerate(section_timestamps):
        # user content: "[Label][desc:...][lyrics:...]" (preserve original spaces)
        label = format_section_label(sec_name)
        content = f"[{label}]"
        content += f"[desc:{desc}]"
        # Drop the leading "[...]" section marker from the lyric body
        lyrics_text = re.sub(r'^\[.*?\]\s*\n?', '', text.strip())
        if lyrics_text:
            content += f"[lyrics:\n{lyrics_text}]"
        messages.append({"role": "user", "content": content})
        messages.append(
            {
                "role": "assistant",
                "content": get_token_slice(audio, start_s, end_s),
            }
        )

    # If last timestamp is empty text, add Outro segment
    # Outro's desc should be obtained from desc_sections' last section (usually Outro)
    if outro_start_s is not None and outro_start_s < total_seconds:
        outro_desc = ""
        if desc_sections and desc_sections[-1].get("section", "").lower() == "outro":
            outro_desc = clean_desc(desc_sections[-1].get("desc", ""))

        messages.append({"role": "user", "content": f"[Outro][desc:{outro_desc}]"})
        messages.append(
            {
                "role": "assistant",
                "content": get_token_slice(audio, outro_start_s, total_seconds),
            }
        )

    sample = {
        "song_id": item.get("song_id"),
        "track_index": item.get("track_index"),
        "src_path": item.get("src_path"),
        "tag": item.get("tag"),
        "lang": item.get("lang"),
        "duration": item.get("duration"),
        "messages": messages,
    }
    return sample
409
+
410
+
411
def process_with_desc(desc_path: str, suffix: str) -> None:
    """
    Generate output files using a single desc file (messages-only / meta-only).
    Does not generate main file.

    Supports resuming: rows already present in the output files are counted
    once and that many input rows are skipped before appending.
    """
    desc_map = build_desc_map(desc_path)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # messages-only naming: remove suffix, only keep block number (e.g., ..._8000.jsonl)
    out_msg = os.path.join(OUTPUT_DIR, f"{OUTPUT_BASENAME}_{suffix}.jsonl")
    out_meta = os.path.join(OUTPUT_DIR, f"{OUTPUT_BASENAME}_{suffix}_meta_only.jsonl")

    dataset = load_jsonl(META_FILE)
    total = len(dataset)
    kept = 0
    skipped = 0

    # Resume support.  BUGFIX: the previous version looped over both output
    # paths and sliced `dataset` once per existing file, so a normal resume
    # (both files present) skipped twice as many rows as had actually been
    # written; it also crashed when only one of the two files existed.
    if os.path.exists(out_msg) or os.path.exists(out_meta):
        already_msg = load_jsonl(out_msg) if os.path.exists(out_msg) else []
        already_meta = load_jsonl(out_meta) if os.path.exists(out_meta) else []
        assert len(already_meta) == len(already_msg), "resume files are out of sync"
        dataset = dataset[len(already_msg):]

    with open(out_msg, "a", encoding="utf-8") as fout_msg, \
         open(out_meta, "a", encoding="utf-8") as fout_meta:

        for item in tqdm(dataset, desc=f"Processing meta_suno_{LANG}.jsonl (desc: {suffix})"):
            key = (item.get("song_id"), item.get("track_index", 0))

            obj = desc_map.get(key)
            if not obj:
                # No segment descriptions for this track
                skipped += 1
                continue

            pt_path = infer_pt_path(item.get("src_path", ""))
            if not pt_path or not os.path.exists(pt_path):
                # Missing mucodec token file
                skipped += 1
                continue

            audio = load_pt(pt_path)
            sample = build_messages(item, obj, audio)
            if not sample:
                skipped += 1
                continue

            # Write messages-only
            messages_only = {"messages": sample.get("messages", [])}
            fout_msg.write(json.dumps(messages_only, ensure_ascii=False) + "\n")
            # Write meta-only
            meta_only = {k: v for k, v in sample.items() if k != "messages"}
            fout_meta.write(json.dumps(meta_only, ensure_ascii=False) + "\n")
            kept += 1

    print(f"✅ messages-only: {out_msg}")
    print(f"✅ meta-only: {out_meta}")
    print(f"Total {total}, kept {kept}, skipped {skipped}")
470
+
471
+
472
def convert_train_valid():
    """Run process_with_desc over every desc jsonl whose stem ends with the configured language."""
    # Collect desc files
    if not os.path.isdir(DESC_DIR):
        print(f"⚠️ DESC_DIR does not exist: {DESC_DIR}")
        return

    candidates = [
        os.path.join(DESC_DIR, fname)
        for fname in os.listdir(DESC_DIR)
        if fname.endswith(".jsonl")
    ]

    # Sort by extracted numeric suffix, ensure 8000 comes before 16000
    desc_files = sorted(candidates, key=lambda p: extract_suffix_num(p, 0))
    if not desc_files:
        print(f"⚠️ No desc files found: {DESC_DIR}")
        return

    for desc_path in desc_files:
        stem = os.path.splitext(os.path.basename(desc_path))[0]
        if stem.endswith(LANG):
            process_with_desc(desc_path, stem)
502
+
503
+ # assistant needs, but content is empty
504
+ # Then each section inside has 3 descs, you take the one corresponding to the maximum value
505
+ # omni also has two, take the larger one (style is not used)
506
+
507
+ from meta_phonemes import _get_lyrics, _trans_sentences
508
+
509
def _form_section(section:dict, en:bool) -> str:
    """Process inside a section, determine which desc to select.

    Builds "[Section][desc:...][lyrics:...]<phonemes>" for one section, using
    whichever of desc1/desc2/desc3 has the highest similarity score.
    """
    # Segment label
    section_tag = f"[{section['section']}]"
    # Segment description: pick the candidate with the best similarity
    descs = [section['desc1'], section['desc2'], section['desc3']]
    sims = [section['desc1_sim'], section['desc2_sim'], section['desc3_sim']]
    max_sim = max(sims)
    max_index = sims.index(max_sim)
    desc:str = descs[max_index]

    if desc == "音频过短":  # "Audio too short" - keep original Chinese as it's part of data processing logic
        desc = "[desc:]"
    else:
        DESC_START = "[desc] "
        if desc.startswith(DESC_START):
            desc = desc[len(DESC_START):] + "]"
        # NOTE(review): `desc[1:]` drops the first character of the remaining
        # text; this only looks correct when desc still starts with a bracket
        # at this point — confirm against the desc data format.
        desc = "[desc:" + desc[1:]

    # Lyrics & phonemes
    text:str = section['text']
    if text.find(']') != -1:
        # Remove preceding segment label
        start = text.rfind(']')
        text = text[start+1:]
    if len(text.strip()) == 0:
        # Opening segment has no lyrics/phonemes
        lyrics = ""
        phonemes = ""
    else:
        if en:
            lyrics = "[lyrics:\n" + clean_newlines(text) + "]"
        else:
            lyrics = "[lyrics:" + text + "]"

        sentences, lyrics = _get_lyrics(lyrics)
        phonemes = _trans_sentences(sentences)
    return section_tag + desc + lyrics + phonemes
547
+
548
+ def _form_intro(ele:dict) -> str:
549
+ """Process to get multi-turn dialogue opening"""
550
+ omni1 = ele['omni1']
551
+ omni2 = ele['omni2']
552
+ omni_sim1 = ele['omni1_sim']
553
+ omni_sim2 = ele['omni2_sim']
554
+ tag = omni1 if omni_sim1 > omni_sim2 else omni2
555
+ return (
556
+ f"Please generate a song in the following style:{tag}.\n"
557
+ "Next, I will tell you the requirements and lyrics for the song fragment to be generated, section by section.\n"
558
+ )
559
+
560
def convert_test():
    """Build test-set conversations (assistant turns left empty) from filter.jsonl."""
    path = "filter.jsonl"
    dataset = load_jsonl(path)
    save_path = "messages.jsonl"
    with open(save_path, 'w', encoding='utf-8') as file:
        for ele in tqdm(dataset, desc=f"Converting {path}"):
            song_id:str = ele['song_id']
            english = song_id.startswith("suno_test_en")
            # One user/assistant pair per section; assistant content to be generated
            messages = []
            for section in ele['sections']:
                messages.append({"role": "user", "content": _form_section(section, english)})
                messages.append({"role": "assistant", "content": ""})
            # Prepend the style intro to the first user turn
            messages[0]['content'] = _form_intro(ele) + messages[0]['content']

            json.dump({"messages": messages}, file, ensure_ascii=False)
            file.write("\n")
592
+ if __name__ == "__main__":
593
+ convert_test()
data_pipeline/meta_process/convert_segments.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ from my_tool import path_join, load_json
5
+ from concurrent.futures import ProcessPoolExecutor, as_completed
6
+
7
+ def _check_label(label:str, max_length:int=30) -> bool:
8
+ """Check if label is valid (non-empty, not timestamp, not long lyrics)"""
9
+ length = len(label.strip())
10
+ if length == 0:
11
+ # print("Error Label: Empty")
12
+ return False
13
+ if length > max_length:
14
+ # print(f"Error Label: Words - {label}")
15
+ return False
16
+ if label.find(":") != -1 and label.find(".") != -1:
17
+ # Considered as timestamp
18
+ # print(f"Error Label: Timestamp - {label}")
19
+ return False
20
+ return True
21
+
22
def _convert_one(path:str):
    """Segment a song's metadata, remove redundant content.

    Reads one per-song JSON file and collapses its word-level alignment into
    labeled segments of the form {start, end, label, word}.
    NOTE(review): the source this was recovered from had its indentation
    stripped; the if/else nesting below was reconstructed and should be
    checked against the original file.
    """
    data = load_json(path)
    dir = os.path.dirname(path)
    name = f"{data['song_id']}_{data['track_index']}.mp3"
    path = path_join(dir, name)
    new_data = {
        "path": path,
        "song_id": data['song_id'],
        "segments": []
    }
    words_info = data['timestamped_lyrics']['alignedWords']  # Sentence-by-sentence information
    seg_info = None

    empty_head = False
    for id, word_info in enumerate(words_info):
        if not word_info['success']:
            continue
        word:str = word_info['word']

        label = ""
        if word.startswith('['):
            # A bracketed word may open a new segment; flush the previous one first
            if seg_info is not None:
                new_data['segments'].append(seg_info)
            label_end = word.find(']')
            label = word[1:label_end]
            if not _check_label(label):
                label = ""

            if label != "":
                # Valid label: start a fresh segment with the text after "] "
                seg_info = {
                    "start": word_info['startS'],
                    "end": 0,
                    "label": label,
                    "word": word[label_end+2:]
                }
            elif seg_info is not None:
                # Invalid label: treat the raw word as ordinary content
                # (seg_info was just appended; mutating it still updates the list entry)
                seg_info['end'] = word_info['endS']
                seg_info['word'] += word
            else:
                empty_head = True
        else:
            if seg_info is not None:
                # Plain word: extend the current segment
                seg_info['end'] = word_info['endS']
                seg_info['word'] += word
            else:
                # Words before the first labeled segment
                empty_head = True
    # NOTE(review): the final open segment is never appended to
    # new_data['segments'] after the loop — confirm whether that is intended.
    if empty_head:
        # print(f"Empty Head, segment: {len(new_data['segments'])}, path: {path}")
        pass
    return new_data
72
+
73
+ # ===== External Interface =====
74
+
75
def get_convert_segments(data_dir:str, save_path:str, max_workers:int=10):
    """Convert every per-song JSON metadata file in *data_dir* into segment records.

    Work is fanned out over a process pool; results are streamed to
    *save_path* as JSONL in completion order and also returned as a list.
    """
    json_paths = [
        path_join(data_dir, name)
        for name in tqdm(os.listdir(data_dir), desc="Getting the JSON Paths")
        if name.endswith(".json")
    ]

    dataset = []
    with open(save_path, 'w', encoding='utf-8') as sink:
        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            pending = [pool.submit(_convert_one, p) for p in json_paths]
            with tqdm(total=len(pending), desc="Converting Segments") as pbar:
                for done in as_completed(pending):
                    record = done.result()
                    dataset.append(record)
                    json.dump(record, sink, ensure_ascii=False)
                    sink.write("\n")
                    pbar.update(1)
    return dataset
data_pipeline/meta_process/evaluate_polyphones.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import jieba
3
+ from tqdm import tqdm
4
+ from my_tool import load_jsonl, save_json, load_json
5
+ from pypinyin import pinyin, Style, load_phrases_dict
6
+ from pypinyin_dict.phrase_pinyin_data import cc_cedict
7
+
8
+ cc_cedict.load()
9
+ re_special_pinyin = re.compile(r'^(n|ng|m)$')
10
+
11
+ # Add
12
+ reference = load_json("poly_correct.json")
13
+ load_phrases_dict(reference)
14
+
15
def _filter(dataset:list[dict]):
    """Keep only test-set entries whose target character is a true polyphone
    (i.e. pypinyin reports more than one candidate reading for it)."""
    kept = []
    for item in tqdm(dataset, desc="Filtering"):
        target_char = item['sentence'][item['pos']]
        readings = pinyin(target_char, style=Style.NORMAL, heteronym=True)[0]
        if len(readings) > 1:
            kept.append(item)
    print(f"Filter non polyphone, {len(dataset)} -> {len(kept)}")
    return kept
27
+
28
def evaluate_polyphones(dataset:list[dict], save_fail:str):
    """Check pinyin processing accuracy for polyphones.

    For each entry, locate the jieba segment containing character position
    `pos`, predict its pinyin, and compare with the reference `phone`.
    Mispredicted phrase pronunciations are merged into `save_fail` so they
    can be loaded as corrections on the next run.
    """
    dataset = _filter(dataset)
    total = len(dataset)
    right = 0
    correct_dic = {}
    for ele in tqdm(dataset):
        pos = ele['pos']
        phone = ele['phone']
        sentence = ele['sentence']
        # Find the jieba segment that contains character position `pos`.
        seg = None
        delta = 0
        length = 0
        for cand in jieba.cut(sentence):
            if length <= pos < length + len(cand):
                seg = cand
                delta = pos - length  # Position in segment
                break
            length += len(cand)
        if seg is None:
            # BUGFIX: `pos` out of range used to leave `seg`/`delta` undefined
            # (loop completed without break) and raise NameError.
            total -= 1
            continue
        pred_phones = pinyin(seg, style=Style.NORMAL)
        pred_phone = pred_phones[delta][0]
        if pred_phone == phone or pred_phone.endswith("v"):
            right += 1
        elif len(pred_phones) > 1:
            # Corrected pronunciation (only meaningful for phrases)
            pred_phones[delta] = [phone]
            correct_dic[seg] = pred_phones
    if total > 0:  # guard against an empty (or fully skipped) test set
        print(f"Acc: {(right / total):.2f}")

    origin_dic = load_json(save_fail)
    merge_dic = origin_dic | correct_dic
    save_json(merge_dic, save_fail)
58
+
59
if __name__ == "__main__":
    # Evaluate polyphone pinyin accuracy on the test set and merge any
    # corrected phrase pronunciations back into poly_correct.json.
    path = "polyphones.jsonl"
    dataset = load_jsonl(path)
    evaluate_polyphones(dataset, "poly_correct.json")
data_pipeline/meta_process/filter.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ from tqdm import tqdm
3
+ from concurrent.futures import ProcessPoolExecutor, as_completed
4
+
5
def filter_lang(dataset:list[dict], langs:list[str]) -> list[dict]:
    """Filter dataset, only keep items whose 'lang' tag is one of `langs`."""
    kept = [
        item
        for item in tqdm(dataset, desc="Filtering Lang")
        if 'lang' in item and item['lang'] in langs
    ]
    print(f"filter: {len(dataset)} -> {len(kept)}")
    return kept
14
+
15
def _check_duration(ele, lower_bound, upper_bound):
    """Subprocess task: return `ele` if its audio duration lies inside
    [lower_bound, upper_bound], else None. A bound of -1 disables that side.
    """
    # NOTE(review): `filename=` is deprecated in newer librosa releases
    # (renamed to `path=`) — confirm the pinned librosa version on upgrade.
    duration = librosa.get_duration(filename=ele['path'])
    too_short = lower_bound != -1 and duration < lower_bound
    too_long = upper_bound != -1 and duration > upper_bound
    return None if (too_short or too_long) else ele
23
+
24
def filter_length(dataset:list[dict], lower_bound:int=-1, upper_bound:int=-1, max_worker:int=4) -> list[dict]:
    """Filter dataset, only keep items with length in [lower_bound, upper_bound], if set to -1 then no limit on that side"""
    kept = []
    with ProcessPoolExecutor(max_workers=max_worker) as pool:
        # Duration probing is fanned out to worker processes.
        jobs = [
            pool.submit(_check_duration, item, lower_bound, upper_bound)
            for item in dataset
        ]
        with tqdm(total=len(jobs), desc="Filtering Length") as progress:
            for job in as_completed(jobs):
                item = job.result()
                if item is not None:
                    kept.append(item)
                progress.update(1)
    print(f"filter: {len(dataset)} -> {len(kept)}")
    return kept
data_pipeline/meta_process/main.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from my_tool import (
2
+ load_json,
3
+ load_jsonl,
4
+ load_txt,
5
+ save_jsonl,
6
+ format_meta,
7
+ pure_name,
8
+ BASE_DIR,
9
+ compose_analyze,
10
+ get_sample,
11
+ get_field_suno,
12
+ tags_analyze,
13
+ find_json,
14
+ show_dir,
15
+ convert_mp3,
16
+ tar_dir,
17
+ tar_size_check,
18
+ clean_newlines,
19
+ dict_sort_print,
20
+ )
21
+ from meta_lang import load_asr_model, get_lang_meta
22
+ from meta_tags import load_tag_model, get_tags_meta
23
+ from meta_endpoints import get_endpoints_meta
24
+ from meta_phonemes import get_phonemes_meta
25
+ from filter import filter_lang, filter_length
26
+ from convert_convs import get_convert_convs
27
+ from convert_segments import get_convert_segments
28
+ from convert_lyrics import get_convert_lyrics, get_match_music
29
+
30
def pipeline():
    """End-to-end metadata pipeline for one batch of songs.

    Stages: load/create the raw JSONL manifest, filter by duration, then run
    style tagging. The language-tagging and language-filtering stages are
    currently disabled (kept below, commented out, as toggles).
    """
    import os
    dir = "suno_batch"
    name = pure_name(dir)
    save_dir = BASE_DIR / f"data/{name}"

    # Initialize paths (only once): reuse raw.jsonl when it already exists.
    os.makedirs(save_dir, exist_ok=True)
    raw_path = os.path.join(save_dir, "raw.jsonl")
    if os.path.exists(raw_path):
        dataset = load_jsonl(raw_path)
    else:
        dataset = format_meta(dir)
        save_jsonl(dataset, raw_path)

    # Length filtering (NOTE: only the first 1000 entries are processed)
    dataset = dataset[:1000]
    max_workers = 10
    dataset = filter_length(dataset, 120, 360, max_workers)

    # Language tagging (disabled)
    # lang_bs = 8
    # model = load_asr_model(lang_bs)
    # lang_path = os.path.join(save_dir, "meta_lang.jsonl")
    # dataset = get_lang_meta(model, dataset, lang_bs, lang_path)

    # Language filtering (disabled)
    # dataset = filter_lang(dataset, ['zh', 'en'])

    # Style tagging with the Qwen omni model
    tag_bs = 4
    tag_path = os.path.join(save_dir, "meta_tags.jsonl")
    model, processor = load_tag_model()
    prompt_path = BASE_DIR / "prompts/new_tags.md"
    prompt = load_txt(prompt_path)
    get_tags_meta(model, processor, dataset, prompt, tag_bs, tag_path)
66
+
67
def repeat(func):
    """Keep invoking `func` until it completes without raising.

    Every exception is printed and the call retried; the first successful
    call ends the loop. (KeyboardInterrupt/SystemExit still propagate.)
    """
    while True:
        try:
            func()
        except Exception as e:
            print(f"Error: {e}")
        else:
            break
75
+
76
if __name__ == "__main__":
    # Retry the whole pipeline until it finishes without raising.
    repeat(pipeline)
data_pipeline/meta_process/meta_endpoints.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import webrtcvad
3
+ import collections
4
+ from tqdm import tqdm
5
+ from my_tool import dup_remove
6
+ from pydub import AudioSegment
7
+ from concurrent.futures import ProcessPoolExecutor, as_completed
8
+
9
+ def _frame_generator(frame_duration_ms, audio, sample_rate):
10
+ """Split audio into frames"""
11
+ bytes_per_sample = 2
12
+ frame_size = int(sample_rate * frame_duration_ms / 1000.0) * bytes_per_sample
13
+ offset = 0
14
+ timestamp = 0.0
15
+ frame_duration = frame_duration_ms / 1000.0
16
+ while offset + frame_size < len(audio):
17
+ yield audio[offset:offset + frame_size], timestamp
18
+ timestamp += frame_duration
19
+ offset += frame_size
20
+
21
+ def _vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
22
+ """Merge continuous vocal segments based on webrtcvad"""
23
+ num_padding_frames = int(padding_duration_ms / frame_duration_ms)
24
+ ring_buffer = collections.deque(maxlen=num_padding_frames)
25
+
26
+ triggered = False
27
+ speech_segments = []
28
+
29
+ for frame_bytes, timestamp in frames:
30
+ is_speech = vad.is_speech(frame_bytes, sample_rate)
31
+
32
+ if not triggered:
33
+ ring_buffer.append((frame_bytes, timestamp, is_speech))
34
+ num_voiced = len([f for f in ring_buffer if f[2]])
35
+ if num_voiced > 0.9 * ring_buffer.maxlen:
36
+ triggered = True
37
+ start_time = ring_buffer[0][1]
38
+ ring_buffer.clear()
39
+ else:
40
+ ring_buffer.append((frame_bytes, timestamp, is_speech))
41
+ num_unvoiced = len([f for f in ring_buffer if not f[2]])
42
+ if num_unvoiced > 0.9 * ring_buffer.maxlen:
43
+ end_time = timestamp + (frame_duration_ms / 1000.0)
44
+ speech_segments.append((start_time, end_time))
45
+ triggered = False
46
+ ring_buffer.clear()
47
+
48
+ # If still in speech state at the end, close the last segment
49
+ if triggered:
50
+ end_time = timestamp + (frame_duration_ms / 1000.0)
51
+ speech_segments.append((start_time, end_time))
52
+
53
+ return speech_segments
54
+
55
def _one_process(path):
    """Detect vocal segments in one audio file.

    Returns {"start", "end", "segments"}; start/end are -1 when no speech
    is found at all.
    """
    # Downmix to 16 kHz mono 16-bit PCM — the format webrtcvad expects.
    audio = (
        AudioSegment.from_file(path)
        .set_frame_rate(16000)
        .set_channels(1)
        .set_sample_width(2)
    )

    # Aggressiveness 0 (range 0-3): most permissive, most frames pass as speech.
    detector = webrtcvad.Vad(0)

    pcm_frames = list(_frame_generator(30, audio.raw_data, audio.frame_rate))
    segments = _vad_collector(audio.frame_rate, 30, 300, detector, pcm_frames)

    if not segments:
        # No vocals: flag both endpoints with -1.
        return {"start": -1, "end": -1, "segments": []}

    return {
        "start": segments[0][0],
        "end": segments[-1][1],
        "segments": segments,
    }
85
+
86
+ # ===== External Interface =====
87
+
88
def get_endpoints_meta(dataset:list[dict], save_path:str, max_workers:int=4, save_middle:bool=True):
    """
    Add endpoint labels to each audio in dataset (mainly for separated vocal audio)
    - Requires 'path' field in each data entry in dataset
    - Write fields: endpoints.start/end
    - Write to save_path in real-time (JSONL, append mode)
    - save_middle determines whether to record each sentence's endpoints to save.segments field
    - Entries whose VAD run raises are silently dropped (best-effort).
    """
    # Skip entries already annotated with 'endpoints' in save_path.
    dataset = dup_remove(dataset, save_path, 'path', 'endpoints')
    new_dataset = []
    with open(save_path, 'a', encoding='utf-8') as file:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(_one_process, ele['path']): ele for ele in dataset}
            # Completion order, not submission order; each result is flushed
            # to disk immediately so a crash loses at most one entry.
            for future in tqdm(as_completed(futures), desc="Detecting endpoints"):
                ele = futures[future]  # Get original element
                try:
                    result = future.result()
                    ele['endpoints'] = {
                        "start": result['start'],
                        "end": result['end']
                    }
                    if save_middle:
                        if "save" not in ele:
                            ele['save'] = {}
                        ele['save']['segments'] = result['segments']
                    new_dataset.append(ele)
                    json.dump(ele, file, ensure_ascii=False)
                    file.write("\n")
                except Exception:
                    # Deliberate best-effort: a failed file is skipped.
                    pass
    return new_dataset
data_pipeline/meta_process/meta_lang.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ from funasr import AutoModel
5
+ from typing import List, Tuple
6
+ from collections import defaultdict
7
+ from my_tool import get_free_gpu, dup_remove
8
+
9
+ # ===== ASR Model (External) =====
10
+
11
def load_asr_model(bs:int):
    """Load the SenseVoice lyric-recognition model onto the freest GPU."""
    device = f"cuda:{get_free_gpu()}"
    asr_config = dict(
        model="iic/SenseVoiceSmall",
        trust_remote_code=True,
        vad_model="fsmn-vad",
        vad_kwargs={"max_single_segment_time": 30000},
        device=device,
        batch_size=bs,
        max_batch_size=bs * 2,
    )
    model = AutoModel(**asr_config)
    print(f"Using {device}")
    return model
25
+
26
+ # ===== ASR Parsing =====
27
+
28
+ def _struct2lang(text: str) -> str:
29
+ """Extract language identifier from structured representation"""
30
+ start = text.find("|")
31
+ text = text[start+1:]
32
+ end = text.find("|")
33
+ return text[:end]
34
+
35
+ def _struct2lyrics(text:str) -> str:
36
+ start = text.rfind(">")
37
+ lyric = text[start+1:]
38
+ return lyric
39
+
40
def _struct_parse(text: str) -> Tuple[List[str], List[str]]:
    """Split the structured ASR output into sentences and parse each into
    its language tag and lyric text."""
    langs, lyrics = [], []
    for chunk in text.split(" <"):
        langs.append(_struct2lang(chunk))
        lyrics.append(_struct2lyrics(chunk))
    return langs, lyrics
48
+
49
+ # ===== ASR Processing =====
50
+
51
def _batch_asr(model, paths:List[str]) -> List[Tuple[List[str], List[str]]]:
    """Batch speech recognition.

    Runs the ASR model over `paths` and parses each structured transcript
    into (per-sentence language tags, per-sentence lyrics).
    """
    outputs = model.generate(
        input=paths,
        cache=None,
        language="auto",   # per the kwarg: automatic language detection
        use_itn=True,
        batch_size_s=240,
        merge_vad=True,
        merge_length_s=15,
    )
    return [_struct_parse(output['text']) for output in outputs]
63
+
64
+ # ===== Overall Language Detection =====
65
+
66
+ def _lang_decide(lang_lyrics:list[tuple[str, str]], val_limit:int=5, word_limit=5) -> str:
67
+ """
68
+ Determine language based on sentence recognition information
69
+ - val_limit: Only count if there are at least this many sentences
70
+ - word_limit: Only count if a sentence has at least this many words
71
+ """
72
+ lang_count = defaultdict(int)
73
+ seg_langs, seg_lyrics = lang_lyrics
74
+ for lang, lyric in zip(seg_langs, seg_lyrics):
75
+ lyric = lyric.strip()
76
+ if lang == "en":
77
+ words_num = len(lyric.split())
78
+ else:
79
+ words_num = len(lyric)
80
+ if words_num >= word_limit:
81
+ lang_count[lang] += 1
82
+ langs = []
83
+ for lang, count in lang_count.items():
84
+ if count >= val_limit:
85
+ langs.append(lang)
86
+ if len(langs) == 0:
87
+ return "pure"
88
+ elif len(langs) == 1:
89
+ return langs[0]
90
+ else:
91
+ return "multi: " + " ".join(langs)
92
+
93
+ # ===== External Interface =====
94
+
95
def get_lang_meta(model, dataset:list[dict], bs:int, save_path:str, save_middle:bool=True) -> list[dict]:
    """
    Perform language recognition on a JSONL dataset
    - Final language tag is saved to lang field, types include zh, en, ja, ko, yue, pure, multi, etc.
    - save_middle determines whether to save intermediate recognition results (sentence languages and lyrics) to save.langs, save.lyrics
    """
    # BUGFIX: measure the length AFTER deduplication. The old code captured
    # len(dataset) first, so stale indices produced empty ASR batches once
    # dup_remove filtered entries out.
    dataset = dup_remove(dataset, save_path, 'path', 'lang')
    data_num = len(dataset)
    new_dataset = []
    with open(save_path, 'a', encoding='utf-8') as file:
        for i in tqdm(range(0, data_num, bs), desc="Lang detecting"):
            batch = []
            paths = []
            for ele in dataset[i:i+bs]:
                path = ele['path']
                if os.path.exists(path):
                    batch.append(ele)
                    paths.append(path)
            if not paths:
                # Every file in this batch is missing on disk; skip the model call.
                continue
            lang_lyrics_lis = _batch_asr(model, paths)
            langs = [_lang_decide(lang_lyrics) for lang_lyrics in lang_lyrics_lis]
            for ele, (seg_langs, seg_lyrics), lang in zip(batch, lang_lyrics_lis, langs):
                ele['lang'] = lang
                if save_middle:
                    if 'save' not in ele:
                        ele['save'] = {}
                    ele['save']['langs'] = seg_langs
                    ele['save']['lyrics'] = seg_lyrics
                new_dataset.append(ele)
                json.dump(ele, file, ensure_ascii=False)
                file.write("\n")
    return new_dataset
data_pipeline/meta_process/meta_phonemes.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import copy
5
+ import jieba
6
+ import string
7
+ from tqdm import tqdm
8
+ from g2p_en import G2p
9
+ from my_tool import BASE_DIR, load_json
10
+ from pypinyin import pinyin, Style, load_phrases_dict
11
+ from pypinyin_dict.phrase_pinyin_data import cc_cedict
12
+
13
+ cc_cedict.load()
14
+ re_special_pinyin = re.compile(r'^(n|ng|m)$')
15
+ reference = load_json("poly_correct.json")
16
+ load_phrases_dict(reference)
17
+
18
+ # ===== Chinese Conversion =====
19
+
20
+ def _split_py(py):
21
+ """Split pinyin with tone number into initial (sm) and final (ym) parts"""
22
+ tone = py[-1]
23
+ py = py[:-1]
24
+ sm = ""
25
+ ym = ""
26
+ suf_r = ""
27
+ if re_special_pinyin.match(py):
28
+ py = 'e' + py
29
+ if py[-1] == 'r':
30
+ suf_r = 'r'
31
+ py = py[:-1]
32
+
33
+ if len(py) == 0:
34
+ # rx
35
+ return "", suf_r + tone
36
+
37
+ if py == 'zi' or py == 'ci' or py == 'si' or py == 'ri':
38
+ sm = py[:1]
39
+ ym = "ii"
40
+ elif py == 'zhi' or py == 'chi' or py == 'shi':
41
+ sm = py[:2]
42
+ ym = "iii"
43
+ elif py == 'ya' or py == 'yan' or py == 'yang' or py == 'yao' or py == 'ye' or py == 'yong' or py == 'you':
44
+ sm = ""
45
+ ym = 'i' + py[1:]
46
+ elif py == 'yi' or py == 'yin' or py == 'ying':
47
+ sm = ""
48
+ ym = py[1:]
49
+ elif py == 'yu' or py == 'yv' or py == 'yuan' or py == 'yvan' or py == 'yue ' or py == 'yve' or py == 'yun' or py == 'yvn':
50
+ sm = ""
51
+ ym = 'v' + py[2:]
52
+ elif py == 'wu':
53
+ sm = ""
54
+ ym = "u"
55
+ elif py[0] == 'w':
56
+ sm = ""
57
+ ym = "u" + py[1:]
58
+ elif len(py) >= 2 and (py[0] == 'j' or py[0] == 'q' or py[0] == 'x') and py[1] == 'u':
59
+ sm = py[0]
60
+ ym = 'v' + py[2:]
61
+ else:
62
+ seg_pos = re.search('a|e|i|o|u|v', py)
63
+ try:
64
+ sm = py[:seg_pos.start()]
65
+ ym = py[seg_pos.start():]
66
+ if ym == 'ui':
67
+ ym = 'uei'
68
+ elif ym == 'iu':
69
+ ym = 'iou'
70
+ elif ym == 'un':
71
+ ym = 'uen'
72
+ elif ym == 'ue':
73
+ ym = 've'
74
+ except Exception:
75
+ sm = ym = ""
76
+ return sm, ym
77
+ ym += suf_r + tone
78
+ return sm, ym
79
+
80
+ # All Chinese punctuation
81
+ chinese_punctuation_pattern = r'[\u3002\uff0c\uff1f\uff01\uff1b\uff1a\u201c\u201d\u2018\u2019\u300a\u300b\u3008\u3009\u3010\u3011\u300e\u300f\u2014\u2026\u3001\uff08\uff09]'
82
+
83
+ def _has_ch_punc(text):
84
+ match = re.search(chinese_punctuation_pattern, text)
85
+ return match is not None
86
+
87
+ def _has_en_punc(text):
88
+ return text in string.punctuation
89
+
90
def _trans_cn(text:str, with_sp=True):
    """Convert Chinese to phonemes.

    Segments `text` with jieba, converts each word to numbered-tone pinyin
    (neutral tone rendered as 5), splits every syllable into initial/final,
    and optionally appends an "sp" pause marker after each word. Words
    containing any punctuation syllable are skipped entirely.
    """
    phonemes = []
    # Word segmentation
    seg_list = jieba.cut(text)
    # Process word by word
    for seg in seg_list:
        # Skip whitespace-only segments
        if seg.strip() == "": continue
        # seg_tn = tn_chinese(seg)
        # Convert to pinyin (tone carried as trailing digit)
        py =[_py[0] for _py in pinyin(seg, style=Style.TONE3, neutral_tone_with_five=True)]
        # Punctuation detection (skip the whole word if present)
        if any([_has_ch_punc(_py) for _py in py]) or any([_has_en_punc(_py) for _py in py]):
            continue
        # Split each syllable into initial (sm) and final (ym)
        # phonemes += _split_py(_py)
        for _py in py:
            sm, ym = _split_py(_py)
            if sm != "":
                phonemes.append(sm)
            if ym != "":
                phonemes.append(ym)
        if with_sp:
            phonemes += ["sp"]
    return phonemes
116
+
117
+ # ===== English Conversion =====
118
+
119
+ def _read_lexicon(lex_path):
120
+ """Read English lexicon"""
121
+ lexicon = {}
122
+ with open(lex_path) as f:
123
+ for line in f:
124
+ temp = re.split(r"\s+", line.strip("\n"))
125
+ word = temp[0]
126
+ phones = temp[1:]
127
+ if word.lower() not in lexicon:
128
+ lexicon[word.lower()] = phones
129
+ return lexicon
130
+
131
+ LEX_PATH = BASE_DIR / f"data/ref/lexion.txt"
132
+ lexicon = _read_lexicon(LEX_PATH)
133
+
134
+ g2p = G2p()
135
+
136
def _trans_en(word:str, with_sp=True):
    """Convert an English word to phonemes (lexicon first, G2P fallback).

    G2P results are cached into the lexicon. Appends an "sp" pause marker
    when the word produced any phonemes and `with_sp` is set.
    """
    key = word.lower()
    if key in lexicon:
        # Known word: copy the cached pronunciation.
        phonemes = list(lexicon[key])
    else:
        # Unknown word: fall back to G2P.
        phonemes = list(g2p(key) or [])
        # BUGFIX: cache a copy. The old code cached the very list it then
        # mutated below, so cached entries accumulated trailing "sp" tokens
        # that were duplicated on every later lookup.
        lexicon[key] = list(phonemes)
    if phonemes and with_sp:
        phonemes.append("sp")
    return phonemes
153
+
154
+ # ===== Single Sentence Processing =====
155
+
156
+ def _char_lang(c:str) -> int:
157
+ """
158
+ Check if a character is Chinese, English, or other
159
+ 0 - Chinese
160
+ 1 - English
161
+ 2 - Number
162
+ 3 - Other
163
+ """
164
+ if '\u4e00' <= c <= '\u9fff':
165
+ return 0
166
+ elif ('a' <= c <= 'z') or ('A' <= c <= 'Z'):
167
+ return 1
168
+ elif c.isdigit():
169
+ return 2
170
+ else:
171
+ return 3
172
+
173
+ NUMBER_MAP = {
174
+ "0": "zero",
175
+ "1": "one",
176
+ "2": "two",
177
+ "3": "three",
178
+ "4": "four",
179
+ "5": "five",
180
+ "6": "six",
181
+ "7": "seven",
182
+ "8": "eight",
183
+ "9": "nine",
184
+ }
185
+
186
def _lang_seperate(text:str) -> "tuple[list[str], list[int]]":
    """Split a string into runs of the same language.

    Returns (segments, tags) where tags[i] is the `_char_lang` code of
    segments[i] (0 = Chinese, 1 = English). Digits are kept only once at
    least 4 English runs have been flushed, and are then spelled out via
    NUMBER_MAP; all other characters are dropped.
    (Annotation corrected: the function returns a 2-tuple, not list[str].)
    """
    lang_segs = []   # Set of split strings
    lang_tags = []   # Tags for each string segment
    lang_seg = ""    # Previous continuous language string
    lang_tag = -1    # Language type of previous character
    en_count = 0     # English runs flushed so far
    for c in text:
        lang = _char_lang(c)
        if lang_tag != lang:
            # Different from previous character type: flush the finished run
            if lang_seg != "":
                lang_segs.append(lang_seg)
                lang_tags.append(lang_tag)
                if lang_tag == 1:
                    en_count += 1
                lang_seg = ""
            if lang == 2 and en_count >= 4:
                # Number conversion inside mostly-English text
                lang_segs.append(NUMBER_MAP[c])
                lang_tags.append(1)
            lang_tag = lang
        if lang < 2:
            # Only Chinese and English characters accumulate into runs
            lang_seg += c
    if lang_seg != "":
        # Last valid segment
        lang_segs.append(lang_seg)
        lang_tags.append(lang_tag)
    return lang_segs, lang_tags
215
+
216
def _phoneme_trans(text:str, with_sp=True):
    """Convert one lyric line into phonemes, handling mixed zh/en text."""
    segments, tags = _lang_seperate(text)
    phonemes = []
    for segment, tag in zip(segments, tags):
        # Tag 0 is Chinese; everything else routes through the English converter.
        converter = _trans_cn if tag == 0 else _trans_en
        phonemes += converter(segment, with_sp)
    return phonemes
230
+
231
+ # ===== Dynamic Adaptation =====
232
+
233
+ def _get_lyrics(raw_content:str) -> list[str]:
234
+ """Extract lyric content from dialogue, format like '[stage][dsec:xxx][lyrics:xxx\nxxx]'"""
235
+ START_FORMAT = "[lyrics:"
236
+ start = raw_content.find(START_FORMAT)
237
+ if start == -1:
238
+ return None, None
239
+ content = raw_content[start+len(START_FORMAT):-1]
240
+ # Filter brackets
241
+ content = re.sub(r'\[.*?\]', '', content) # Complete brackets
242
+ content = re.sub(r'[\[\]]', '', content) # Unclosed brackets
243
+ # Split sentences
244
+ sentences = content.split("\n")
245
+ # Reconstruct
246
+ new_content = raw_content[:start] + START_FORMAT + content + "]"
247
+ return sentences, new_content
248
+
249
def _trans_sentences(sentences:list[str], with_sp:bool=True) -> str:
    """Convert a list of lyric lines into a '[phoneme:...]' wrapped string.

    One phoneme line per sentence; tone digits are stripped at the end.
    """
    per_line = [" ".join(_phoneme_trans(s, with_sp)) for s in sentences]
    joined = "\n".join(per_line)
    wrapped = f"[phoneme:{joined}]"
    return re.sub(r'\d+', '', wrapped)  # Remove tones
260
+
261
+ # ===== External Interface =====
262
+
263
def get_phonemes_meta(dataset:list[dict], save_path:str, with_sp:bool=True):
    """Add phonemes to lyrics in dataset.

    For every conversation, each non-assistant message after the first has
    its '[lyrics:...]' block cleaned and a '[phoneme:...]' block appended.
    Results are streamed to save_path as JSONL and also returned.
    """
    new_dataset = []
    with open(save_path, 'w', encoding='utf-8') as file:
        for ele in tqdm(dataset, desc="Phoneme trans"):
            # Deep copy so the caller's dataset entries stay untouched.
            ele = copy.deepcopy(ele)
            messages = ele['messages']
            # Skip first message, process subsequent ones sentence by sentence
            for message in messages[1:]:
                if message['role'] == "assistant":
                    continue
                content = message['content']
                sentences, new_content = _get_lyrics(content)
                if sentences is None:
                    # No [lyrics:...] block in this message
                    continue
                phonemes = _trans_sentences(sentences, with_sp)
                message['content'] = new_content + phonemes
            new_dataset.append(ele)
            json.dump(ele, file, ensure_ascii=False)
            file.write("\n")
    return new_dataset
data_pipeline/meta_process/meta_tags.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from tqdm import tqdm
5
+ from transformers import (
6
+ Qwen3OmniMoeProcessor,
7
+ Qwen3OmniMoeForConditionalGeneration
8
+ )
9
+ from qwen_omni_utils import process_mm_info
10
+ from my_tool import get_free_gpu, audio_cut, extract_json, dup_remove, BASE_DIR
11
+
12
+ # ===== Tag Model and Processor (External) =====
13
+
14
def load_tag_model():
    """Load the Qwen3-Omni tagging model (text thinker only) and its processor."""
    device = f"cuda:{get_free_gpu()}"
    print(f"Using {device}")
    model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"

    model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
        model_name,
        dtype=torch.bfloat16,
    ).to(device)
    # Tagging only needs text generation; drop the speech decoder.
    model.disable_talker()
    model.eval()

    processor = Qwen3OmniMoeProcessor.from_pretrained(model_name)
    return model, processor
32
+
33
+ # ===== Tag Annotation =====
34
+
35
+ def _format_messages(prompt:str, path:str) -> list[dict]:
36
+ """Construct messages to pass to omni"""
37
+ messages = [
38
+ {
39
+ "role": "system",
40
+ "content": [
41
+ {"type": "text", "text": prompt}
42
+ ]
43
+ },
44
+ {
45
+ "role": "user",
46
+ "content": [
47
+ {"type": "audio", "audio": path},
48
+ ]
49
+ }
50
+ ]
51
+ return messages
52
+
53
def _batch_tagging(model, processor, paths:list[str], prompt:str, mode="random"):
    """Annotate a batch of songs.

    Cuts an excerpt from each audio (strategy `mode` via audio_cut), runs
    the omni model on prompt + excerpt, and returns one generated text per
    input path. Temporary excerpt files are deleted afterwards.
    """
    convs = []
    middle_paths = []
    output_dir = BASE_DIR / "data/temp"
    for path in paths:
        # Cut an excerpt so the audio context stays small.
        seg_path = audio_cut(path, mode, output_dir)
        middle_paths.append(seg_path)
        messages = _format_messages(prompt, seg_path)
        convs.append(messages)

    USE_AUDIO_IN_VIDEO = False

    with torch.no_grad():
        text = processor.apply_chat_template(convs, add_generation_prompt=True, tokenize=False)
        audios, images, videos = process_mm_info(convs, use_audio_in_video=USE_AUDIO_IN_VIDEO)
        inputs = processor(
            text=text,
            audio=audios,
            padding=True,
            images=images,
            videos=videos,
            return_tensors="pt",
            use_audio_in_video=USE_AUDIO_IN_VIDEO
        )
        inputs = inputs.to(model.device).to(model.dtype)

        text_ids = model.generate(
            **inputs,
            max_new_tokens=2048,
            return_audio=False,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=USE_AUDIO_IN_VIDEO
        )
        # Decode only the newly generated tokens (slice off the prompt prefix).
        gene_texts = processor.batch_decode(
            text_ids[0].sequences[:, inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

    torch.cuda.empty_cache()
    # Delete audio segments
    for path in middle_paths:
        if os.path.exists(path):
            os.remove(path)
    return gene_texts
99
+
100
+ # ===== External Interface =====
101
+
102
def get_tags_meta(model, processor, dataset:list[dict], prompt:str, bs:int, save_path:str):
    """Annotate each audio in `dataset` with style tags.

    Entries already tagged in `save_path` are skipped; new results are
    appended there as JSONL and returned. Entries whose model output does
    not contain valid JSON are dropped.
    """
    # BUGFIX: measure the length AFTER deduplication. The old code captured
    # len(dataset) first, so stale indices produced empty batches (and
    # pointless model calls) once dup_remove filtered entries out.
    dataset = dup_remove(dataset, save_path, 'path', 'tags')
    data_num = len(dataset)
    new_dataset = []
    with open(save_path, 'a', encoding="utf-8") as file:
        for i in tqdm(range(0, data_num, bs)):
            batch = []
            paths = []
            for ele in dataset[i:i+bs]:
                path = ele['path']
                if os.path.exists(path):
                    batch.append(ele)
                    paths.append(path)
            if not paths:
                # Every file in this batch is missing on disk.
                continue
            contents = _batch_tagging(model, processor, paths, prompt)
            for ele, content in zip(batch, contents):
                ok, json_data = extract_json(content)  # (local was misspelled `ckeck`)
                if not ok:
                    continue
                ele['tags'] = json_data['tags']
                new_dataset.append(ele)
                json.dump(ele, file, ensure_ascii=False)
                file.write('\n')
    return new_dataset
data_pipeline/meta_process/meta_vocal.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ from typing import List, Dict
4
+
5
+ import torch
6
+ import librosa
7
+ import numpy as np
8
+
9
+ from demucs.pretrained import get_model
10
+ from demucs.apply import apply_model
11
+
12
+
13
+ # ======================
14
+ # Basic Configuration
15
+ # ======================
16
+
17
SAMPLE_RATE = 44100  # Working sample rate for loading/separation
MAX_DURATION = 5  # Seconds analyzed per file. NOTE(review): original comment said "first 30 seconds" and load_audio_30s is named accordingly, but the value is 5 — confirm intent.
BATCH_SIZE = 8  # Can adjust to 16~32 for 80G GPU memory
VOCAL_DB_THRESHOLD = -35.0  # Vocal presence threshold (empirical value)
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cuda:1"  # NOTE(review): hard-coded GPU index — confirm before deploying
23
+
24
+
25
+ # ======================
26
+ # Audio Loading (first 30s only)
27
+ # ======================
28
+
29
def load_audio_30s(path: str, sr: int = SAMPLE_RATE) -> torch.Tensor:
    """Load the opening MAX_DURATION seconds of `path` as a stereo tensor.

    Returns shape [channels=2, samples]; mono input is duplicated onto two
    channels. (Despite the name, the clip length is MAX_DURATION seconds.)
    """
    samples, _ = librosa.load(path, sr=sr, mono=False, duration=MAX_DURATION)
    if samples.ndim == 1:
        samples = np.stack((samples, samples), axis=0)
    return torch.from_numpy(samples)
44
+
45
+
46
+ # ======================
47
+ # dB Calculation
48
+ # ======================
49
+
50
def rms_db(wav: torch.Tensor) -> float:
    """Return the RMS level of `wav` ([channels, samples]) in decibels.

    A tiny epsilon keeps log10 finite for all-zero (silent) input.
    """
    mean_square = torch.mean(wav ** 2)
    level = 20 * torch.log10(torch.sqrt(mean_square) + 1e-8)
    return level.item()
57
+
58
+
59
+ # ======================
60
+ # Demucs Vocal Detection
61
+ # ======================
62
+
63
class DemucsVocalDetector:
    """Detects whether the opening seconds of each track contain vocals.

    Separates the vocal stem with Demucs (htdemucs) and thresholds its RMS
    level against VOCAL_DB_THRESHOLD.
    """
    def __init__(self):
        # Load htdemucs once and keep it resident on DEVICE.
        self.model = (
            get_model("htdemucs")
            .to(DEVICE)
            .eval()
        )
        # Index of the "vocals" stem among the model's output sources.
        self.vocal_idx = self.model.sources.index("vocals")

    @torch.no_grad()
    def batch_has_vocal(self, audio_paths: List[str]) -> Dict[str, bool]:
        """
        Input: List of audio paths
        Output: {path: whether has vocals}
        Files that fail to load are reported as False.
        NOTE(review): an empty `audio_paths` (or all loads failing) makes
        _process_batch call max() on an empty sequence — confirm callers
        never pass an empty list.
        """
        results = {}

        batch_wavs = []
        batch_paths = []

        print("start load")
        for path in audio_paths:
            try:
                wav = load_audio_30s(path)
                batch_wavs.append(wav)
                batch_paths.append(path)
            except Exception as e:
                # Unreadable/corrupt file: mark as "no vocals".
                results[path] = False
                continue

        print("finish load")
        self._process_batch(batch_wavs, batch_paths, results)

        return results

    def _process_batch(self, wavs, paths, results):
        # Right-pad every clip to the longest one so they stack into a batch.
        max_len = max(w.shape[1] for w in wavs)
        padded = []

        for w in wavs:
            if w.shape[1] < max_len:
                w = torch.nn.functional.pad(w, (0, max_len - w.shape[1]))
            padded.append(w)

        batch = torch.stack(padded, dim=0).to(DEVICE)

        print("start demucs")
        # NOTE(review): synchronize() without an argument targets the
        # *current* CUDA device, which may differ from DEVICE ("cuda:1").
        torch.cuda.synchronize()

        # NOTE(review): in demucs 4, apply_model's third positional
        # parameter is `shifts`, not a sample rate — passing SAMPLE_RATE
        # (44100) here looks wrong; confirm the intended call signature.
        sources = apply_model(
            self.model,
            batch,
            SAMPLE_RATE,
            device=DEVICE,
            split=False,  # 🔥 Core
            progress=False
        )

        torch.cuda.synchronize()
        print("demucs done")

        # Keep only the separated vocal stem for each clip.
        vocals = sources[:, self.vocal_idx]

        for i, path in enumerate(paths):
            db = rms_db(vocals[i])
            results[path] = db > VOCAL_DB_THRESHOLD
129
+
130
+ # ======================
131
+ # Usage Example
132
+ # ======================
133
+
134
if __name__ == "__main__":
    # Usage example: fill audio_list with file paths before running.
    # NOTE(review): batch_has_vocal([]) raises inside _process_batch
    # (max() of an empty sequence) — the list must be non-empty.
    audio_list = []

    detector = DemucsVocalDetector()
    result = detector.batch_has_vocal(audio_list)

    for k, v in result.items():
        print(f"{k}: {'Has' if v else 'No'} vocals")
data_pipeline/meta_process/my_tool.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import random
5
+ import tarfile
6
+ import subprocess
7
+ import json_repair
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+ from pydub import AudioSegment
11
+ from collections import defaultdict
12
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
13
+
14
+ # ===== Macros =====
15
+
16
BASE_DIR = Path(__file__).parent  # Directory containing this module; base for resolving relative paths.
17
+
18
+ # ===== Helper Functions =====
19
+
20
def pure_name(path: str):
    """Return the base name of *path* with its extension removed.

    A name with no dot is returned unchanged; a leading-dot name like
    '.bashrc' yields ''.
    """
    base = os.path.basename(path)
    stem, sep, _ext = base.rpartition('.')
    return stem if sep else base
27
+
28
def extract_json(text: str) -> tuple[bool, dict]:
    """Extract and repair JSON data from text (enhanced error-tolerant version)

    Features:
    1. Automatically identify code block markers (```json``` or ```)
    2. Fix common JSON errors (mismatched quotes, trailing commas, etc.)
    3. Support lenient parsing mode

    Returns: (success, parsed dictionary); on failure the dict carries the
    standard-parser and repair exceptions.
    """
    content = text

    # Case 1: fenced ```json block. Skip the full 7-char marker (the old
    # `start + 6` left the trailing 'n' of "json" in the content) and
    # tolerate a missing closing fence (find() == -1 used to drop the
    # last character of the payload).
    if '```json' in text:
        start = text.find('```json') + len('```json')
        end = text.find('```', start)
        content = (text[start:end] if end != -1 else text[start:]).strip()
    # Case 2: plain ``` fenced block
    elif '```' in text:
        start = text.find('```') + 3
        end = text.find('```', start)
        content = (text[start:end] if end != -1 else text[start:]).strip()

    # Strip non-JSON noise before the opening brace/bracket and after the
    # closing one.
    content = re.sub(r'^[^{[]*', '', content)
    content = re.sub(r'[^}\]]*$', '', content)

    # Try standard parsing first.
    try:
        return True, json.loads(content)
    except json.JSONDecodeError as e:
        standard_error = e

    # Fall back to json_repair for common LLM-style JSON damage.
    try:
        repaired = json_repair.repair_json(content)
        return True, json.loads(repaired)
    except Exception as repair_error:
        return False, {
            "standard_error": standard_error,
            "repair_error": repair_error
        }
74
+
75
def path_join(dir, name):
    """Thin convenience wrapper around ``os.path.join``."""
    joined = os.path.join(dir, name)
    return joined
77
+
78
def dict_sort_print(dic: dict, value: bool = True, reverse=True):
    """Sort a dictionary by value (or by key when value=False) and print it as JSON.

    - value: sort by the values when True, by the keys when False
    - reverse: descending order when True
    Prints the sorted dict with indent=4; returns None.
    """
    idx = 1 if value else 0
    # Fix: the original loop variable shadowed the `value` parameter;
    # dict() preserves the sorted insertion order directly.
    sorted_dic = dict(sorted(dic.items(), key=lambda item: item[idx], reverse=reverse))
    print(json.dumps(sorted_dic, indent=4, ensure_ascii=False))
86
+
87
def clean_newlines(text: str) -> str:
    """
    Normalise lyric line breaks:
    1. Keep line breaks after punctuation
    2. Convert line breaks after non-punctuation into spaces
    3. Fix stray spaces after English apostrophes
    4. Collapse redundant spaces
    5. Preserve paragraph structure with a break after each punctuation mark
    """
    if not text:
        return ""

    normalised = text.strip()

    # Unify CRLF / CR into LF.
    normalised = normalised.replace('\r\n', '\n').replace('\r', '\n')

    # Flatten all non-empty lines into one space-joined sentence.
    pieces = [piece.strip() for piece in normalised.split('\n')]
    normalised = ' '.join(piece for piece in pieces if piece)

    # Re-break after sentence punctuation (ASCII and CJK).
    normalised = re.sub(r'([.,!?:;,。!?;])\s*', r'\1\n', normalised)

    # Glue back characters that drifted after an apostrophe.
    normalised = re.sub(r"'\s+", "'", normalised)

    # Squash runs of spaces/tabs into a single space.
    normalised = re.sub(r'[ \t]+', ' ', normalised)

    # Trim each resulting line.
    normalised = '\n'.join(ln.strip() for ln in normalised.split('\n'))

    return normalised.strip()
121
+
122
+ # ===== Detection Functions =====
123
def is_ch_char(char: str):
    """Return True when *char* is exactly one Chinese character.

    Covers CJK Unified Ideographs (0x4E00-0x9FFF) and Extension A
    (0x3400-0x4DBF); Extensions B-E are deliberately not handled.
    """
    if len(char) != 1:
        return False

    cp = ord(char)
    return (0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF)
147
+
148
+ # ===== File Operations =====
149
+
150
def load_txt(path: str) -> str:
    """Read *path* and return its entire contents as a string."""
    with open(path, 'r') as handle:
        return handle.read()
155
+
156
def load_json(path: str):
    """Load a JSON file; a missing path yields an empty dict instead of raising."""
    if not os.path.exists(path):
        return {}
    with open(path, 'r') as handle:
        return json.load(handle)
163
+
164
def load_jsonl(path: str, limit=-1) -> list[dict]:
    """Load a JSONL file, stopping after *limit* records unless limit == -1."""
    records = []
    with open(path, 'r') as handle:
        for idx, raw in tqdm(enumerate(handle), desc=f"Loading {path}"):
            if limit != -1 and idx == limit:
                break
            records.append(json.loads(raw))
    return records
173
+
174
def save_json(data, path: str):
    """Write *data* to *path* as pretty-printed UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as handle:
        json.dump(data, handle, ensure_ascii=False, indent=4)
178
+
179
def save_jsonl(data: list[dict], path: str, mode='w'):
    """Write records to a JSONL file, one JSON object per line."""
    with open(path, mode, encoding='utf-8') as handle:
        for record in tqdm(data, desc=f"Saving to {path}"):
            json.dump(record, handle, ensure_ascii=False)
            handle.write("\n")
185
+
186
def audio_cut(input_path, mode:str, output_dir:str, segment_length:int=30000):
    """
    Extract a segment of specified length from an audio file.

    - input_path: source audio file (any format pydub/ffmpeg can decode)
    - mode: Cut type (random / middle)
    - output_dir: Output folder
    - segment_length: Segment length (milliseconds)

    Returns the path of the exported segment: mono, 44.1 kHz, 16-bit PCM WAV,
    named ``seg_<original-stem>.wav`` inside output_dir.

    Raises FileNotFoundError if the input is missing, ValueError if the
    computed slice range is invalid.
    """
    assert mode in ['random', 'middle']

    # Check if file exists
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Audio file not found: {input_path}")

    # Load audio file
    audio = AudioSegment.from_file(input_path)
    audio = audio.set_frame_rate(44100).set_channels(1)  # Set sample rate and channels
    audio_duration = len(audio)  # Duration control (pydub lengths are in milliseconds)

    # If audio length is less than target segment length, use entire audio
    if audio_duration <= segment_length:
        print(f"Warning: Audio too short ({audio_duration}ms), using full audio: {input_path}")
        segment = audio
    else:
        # Calculate slice position based on mode
        if mode == "random":
            # Random cut: any window that still fits inside the track
            max_start = max(0, audio_duration - segment_length)
            start = random.randint(0, max_start)
            end = start + segment_length
        else:
            # Cut from middle: center the window on the midpoint
            middle_point = audio_duration // 2
            start = max(0, middle_point - (segment_length // 2))
            end = min(audio_duration, start + segment_length)

            # If cutting from middle would exceed boundaries, adjust start position
            if end > audio_duration:
                end = audio_duration
                start = end - segment_length
            elif start < 0:
                start = 0
                end = segment_length

        # Ensure slice range is valid (clamp both ends into [0, duration])
        start = max(0, min(start, audio_duration))
        end = max(0, min(end, audio_duration))

        if start >= end:
            raise ValueError(f"Invalid slice range: start={start}, end={end}, duration={audio_duration}")

        # Execute slice (pydub slicing is by milliseconds)
        segment = audio[start:end]

    # Generate output path
    basename = pure_name(input_path)
    output_path = os.path.join(output_dir, f"seg_{basename}.wav")

    # Save segment
    segment.export(
        output_path,
        format="wav",
        codec="pcm_s16le",  # 16-bit little-endian encoding
        parameters=["-acodec", "pcm_s16le"]  # ffmpeg parameters
    )
    return output_path
251
+
252
def format_meta(dir: str, show: bool = True) -> list[dict]:
    """Recursively collect all audio paths (wav / mp3) under *dir* as JSONL rows.

    - show: display a tqdm progress bar; recursive calls always disable it so
      only the top-level directory is reported.
    Returns a list of {"path": <file path>} dicts; [] when *dir* is not a directory.
    """
    if not os.path.isdir(dir):
        return []
    dataset = []
    names = os.listdir(dir)
    # Fix: the original duplicated the whole loop body just to toggle the
    # progress bar; choose the iterator once instead.
    iterator = tqdm(names, desc=f"Formating {dir}") if show else names
    for name in iterator:
        path = os.path.join(dir, name)
        if os.path.isdir(path):
            dataset += format_meta(path, False)
        elif name.endswith(('.mp3', '.wav')):
            dataset.append({"path": path})
    return dataset
272
+
273
def dup_remove(raw_data: list[dict], save_path: str, key: str, seg: str):
    """
    Remove items that were already generated from the dataset.
    - key: primary key in the raw dataset, foreign key in the save file
    - seg: the target field whose presence marks a record as completed
    Returns the records of raw_data not yet present in save_path.
    """
    if not os.path.exists(save_path):
        print(f"Dup num: 0")
        return raw_data
    done_keys = set()
    for record in tqdm(load_jsonl(save_path), desc="Constructing Dup Set"):
        if seg in record:
            done_keys.add(record[key])
    remaining = []
    duplicates = 0
    for record in tqdm(raw_data, desc="Checking Dup"):
        if record[key] in done_keys:
            duplicates += 1
        else:
            remaining.append(record)
    print(f"Dup num: {duplicates}")
    return remaining
296
+
297
def tar_size_check(data_dir: str, subfixes: list[str], per: int, max_size: int):
    """
    Determine how many files fit in a block before compression (assuming
    roughly uniform file sizes).
    - data_dir: folder to inspect
    - subfixes: file suffixes to count (e.g. '.mp3')
    - per: check the accumulated size every N matching files
    - max_size: size limit in GB; stop counting once exceeded
    """
    names = sorted(os.listdir(data_dir))
    count = 0
    size_sum = 0
    for name in tqdm(names, desc="Size Checking"):
        path = os.path.join(data_dir, name)
        if os.path.splitext(name)[1] not in subfixes:
            continue
        count += 1
        size_sum += os.path.getsize(path)
        if count % per == 0 and size_sum / 1024 / 1024 / 1024 > max_size:
            break
    # Fix: gb_size was previously computed only at `per` checkpoints, so the
    # final print raised NameError when fewer than `per` files matched.
    gb_size = size_sum / 1024 / 1024 / 1024
    print(f"Count: {count}, Size: {gb_size:.2f}GB")
320
+
321
def tar_dir(
    data_dir:str,
    subfixes:list[str],
    save_dir:str,
    group_size:int,
    tmp_dir:str,
    mark:str,
    max_workers:int=10,
):
    """Compress files in a directory in chunks (non-recursive).

    - data_dir: directory whose direct children are archived
    - subfixes: extensions to include (e.g. ['.mp3']); others are skipped
    - save_dir: destination for the ``block_<i>_<mark>.tar.gz`` archives
    - group_size: number of directory entries considered per block
    - tmp_dir: scratch directory for the per-block file lists fed to tar
    - mark: label embedded in list/archive names to keep runs distinct
    - max_workers: number of pigz compression threads
    """
    names = sorted(list(os.listdir(data_dir)))
    file_num = len(names)
    for i in range(0, file_num, group_size):
        names_subset = names[i:i+group_size]
        size_sum = 0
        # Write this block's file list; tar consumes it via --files-from.
        name_path = os.path.join(tmp_dir, f"name_{i}_{mark}")
        with open(name_path, 'w', encoding='utf-8') as file:
            for name in tqdm(names_subset, desc=f"Counting Block {i}"):
                path = os.path.join(data_dir, name)
                subfix = os.path.splitext(path)[1]
                if subfix not in subfixes:
                    continue
                file.write("./" + name + "\n")
                size_sum += os.path.getsize(path)
        gb_size = size_sum / 1024 / 1024 / 1024
        print(f"Zipping block {i+1}, size: {gb_size:.2f}GB")

        # Pipeline: tar (archive only, stdout) -> pigz (parallel gzip) ->
        # output file, avoiding any intermediate uncompressed archive on disk.
        tar_cmd = [
            'tar',
            '--no-recursion',
            '--files-from', str(name_path),
            '-cf', '-'
        ]
        pigz_cmd = ['pigz', '-p', str(max_workers), '-c']

        tar_process = subprocess.Popen(tar_cmd, stdout=subprocess.PIPE, cwd=data_dir)
        # NOTE(review): tar_process.stdout is not closed in the parent, so a
        # failing pigz may not deliver SIGPIPE back to tar — confirm acceptable.
        pigz_process = subprocess.Popen(pigz_cmd, stdin=tar_process.stdout, stdout=subprocess.PIPE, cwd=data_dir)

        save_path = os.path.join(save_dir, f"block_{i}_{mark}.tar.gz")
        with open(save_path, 'wb') as out_file:
            # Stream the compressed output to disk in 4 KiB chunks.
            while True:
                data = pigz_process.stdout.read(4096)
                if not data:
                    break
                out_file.write(data)

        tar_process.wait()
        pigz_process.wait()

        if tar_process.returncode == 0 and pigz_process.returncode == 0:
            print(f"Compression completed: {save_path}")
        else:
            print(f"Compression failed: tar return code={tar_process.returncode}, pigz return code={pigz_process.returncode}")
374
+
375
def music_avg_size(dir: str):
    """Average music size (MB) and length (s), sampled over up to 50 files.

    Returns (avg_size_mb, avg_length_s); (0.0, 0.0) when the directory holds
    no audio files.
    """
    dataset = format_meta(dir)[:50]
    # Fix: an empty/non-audio directory previously raised ZeroDivisionError.
    if not dataset:
        return 0.0, 0.0
    size_sum = 0
    length_sum = 0.0
    for ele in tqdm(dataset, desc=f"Counting Music Size in {dir}"):
        path = ele['path']
        audio = AudioSegment.from_file(path)
        length_sum += len(audio) / 1000.0
        size_sum += os.path.getsize(path)
    size_avg = size_sum / len(dataset) / 1024 / 1024
    length_avg = length_sum / len(dataset)
    return size_avg, length_avg
389
+
390
def get_sample(path: str, save_path: str = "tmp.jsonl", num: int = 100):
    """Sample up to *num* records from a JSON/JSONL file and write them to *save_path*.

    Silently returns when *path* does not exist; prints a notice for
    unsupported extensions.
    """
    if not os.path.exists(path):
        return
    if path.endswith(".jsonl"):
        dataset = load_jsonl(path)
    elif path.endswith(".json"):
        dataset = load_json(path)
    else:
        print(f"Unsupport file: {path}")
        return
    # Fix: clamp the sample size so short files no longer raise ValueError
    # from random.sample.
    sub_dataset = random.sample(dataset, min(num, len(dataset)))
    save_jsonl(sub_dataset, save_path)
403
+
404
+ def _get_field_one(path:str, field:str):
405
+ """Process data from one path"""
406
+ with open(path, 'r') as file:
407
+ data = json.load(file)
408
+ new_data = {
409
+ "id": f"{data['song_id']}_{data['track_index']}",
410
+ field: data[field]
411
+ }
412
+ return new_data
413
+
414
def get_field_suno(dir: str, save_path: str, field: str, max_workers: int = 8):
    """Extract one field from every per-song suno JSON in *dir* into a JSONL file."""
    json_paths = []
    for name in tqdm(os.listdir(dir), desc="Getting names"):
        if name.endswith(".json"):
            json_paths.append(os.path.join(dir, name))

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_get_field_one, p, field) for p in json_paths]
        with open(save_path, 'w', encoding='utf-8') as out:
            with tqdm(total=len(json_paths), desc="Processing the JSONs") as pbar:
                # Write results as they finish; output order is arbitrary.
                for future in as_completed(futures):
                    json.dump(future.result(), out, ensure_ascii=False)
                    out.write("\n")
                    pbar.update(1)
431
+
432
def find_json(dir: str) -> list[str]:
    """Return the names of all JSON / JSONL files directly inside *dir*."""
    found = []
    for name in tqdm(os.listdir(dir), desc="Finding JSON/JSONL"):
        if name.endswith((".json", ".jsonl")):
            found.append(name)
    return found
439
+
440
def show_dir(dir: str):
    """Print every entry name directly inside *dir*; no-op for non-directories."""
    if not os.path.isdir(dir):
        return
    for entry in os.listdir(dir):
        print(entry)
446
+
447
def _convert_mp3(path: str, dir: str):
    """Convert a single audio file to mp3 inside *dir*.

    Returns "pass" when the output already exists, "fail" when the source
    cannot be decoded, "finish" on success.
    """
    target = os.path.join(dir, pure_name(path) + ".mp3")
    if os.path.exists(target):
        # Output already produced by a previous run.
        return "pass"
    try:
        audio = AudioSegment.from_file(path)
    except Exception:
        print(f"fail to load {path}")
        return "fail"
    audio.export(target, format='mp3')
    return "finish"
462
+
463
def convert_mp3(meta_path: str, dir: str, max_workers: int = 10):
    """Convert every audio file listed in *meta_path* (JSONL of {"path": ...}) to mp3 under *dir*."""
    os.makedirs(dir, exist_ok=True)
    dataset = load_jsonl(meta_path)
    skipped = 0
    converted = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_convert_mp3, ele['path'], dir) for ele in dataset]
        with tqdm(total=len(dataset), desc=f"Converting {meta_path}") as pbar:
            for future in as_completed(futures):
                if future.result() == "pass":
                    skipped += 1
                else:
                    converted += 1
                pbar.update(1)
    print(f"Finish {converted}, Pass {skipped}")
480
+
481
+ # ===== GPU and Models =====
482
+
483
def get_free_gpu() -> int:
    """Return the index of the GPU with the most free memory (via nvidia-smi)."""
    cmd = "nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits"
    lines = subprocess.check_output(cmd.split()).decode().strip().split("\n")

    gpus = []
    for line in lines:
        idx, free_mem = line.split(",")
        gpus.append((int(idx), int(free_mem)))  # (GPU id, free memory MiB)

    # Most free memory first; stable sort keeps lower indices ahead on ties.
    gpus.sort(key=lambda item: item[1], reverse=True)
    return gpus[0][0]
496
+
497
+ # ===== Data Analysis =====
498
+
499
def compose_analyze(dataset: list[dict]):
    """Statistical analysis of music structure composition."""
    # Frequency of each individual segment label.
    label_counter = defaultdict(int)
    for entry in tqdm(dataset):
        for seg in entry['segments']:
            label_counter[seg['label']] += 1
    print(f"Number of labels: {len(label_counter)}")
    print(dict_sort_print(label_counter))

    # Frequency of each full label sequence (joined as one combination key).
    comb_counter = defaultdict(int)
    for entry in tqdm(dataset):
        sequence = [seg['label'] for seg in entry['segments']]
        if not sequence:
            continue
        comb_counter[" | ".join(sequence)] += 1
    print(f"Number of combinations: {len(comb_counter)}")
    print(dict_sort_print(comb_counter))
525
+
526
+ def _filter_tag(content:str) -> list[str]:
527
+ """Split and format tag fields"""
528
+ tags = []
529
+ raws = re.split(r'[,,.]', content)
530
+ for raw in raws:
531
+ raw = raw.strip().lower() # Remove spaces and convert to lowercase
532
+ if raw == "":
533
+ continue
534
+ seg_pos = raw.find(":")
535
+ if seg_pos != -1:
536
+ # If colon exists, only take the part after it
537
+ tag = raw[seg_pos+1:].strip()
538
+ else:
539
+ tag = raw
540
+ tags.append(tag)
541
+ return tags
542
+
543
def tags_analyze(dataset: list[dict]):
    """Song tag analysis: frequency count over normalised style tags."""
    counter = defaultdict(int)
    for entry in tqdm(dataset, desc="Tag analyzing"):
        for tag in _filter_tag(entry['style']):
            counter[tag] += 1
    print(f"Number of tags: {len(counter)}")
    print(dict_sort_print(counter))