| | import os |
| | import re |
| | import json |
| | import random |
| | import tarfile |
| | import subprocess |
| | import json_repair |
| | from tqdm import tqdm |
| | from pathlib import Path |
| | from pydub import AudioSegment |
| | from collections import defaultdict |
| | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed |
| |
|
| | |
| |
|
# Directory containing this module; used to resolve paths relative to the script.
BASE_DIR = Path(__file__).parent
| |
|
| | |
| |
|
def pure_name(path:str):
    """Return the file name of `path` without its final extension.

    A name with no dot is returned unchanged; a leading-dot name like
    ".bashrc" yields the empty string (everything after the last dot is
    treated as the extension).
    """
    basename = os.path.basename(path)
    stem, sep, _ext = basename.rpartition('.')
    # rpartition returns an empty separator when no dot exists at all.
    return basename if not sep else stem
| |
|
def extract_json(text: str) -> tuple[bool, dict]:
    """Extract and repair JSON data from text (enhanced error-tolerant version)

    Features:
    1. Automatically identify code block markers (```json``` or ```)
    2. Fix common JSON errors (mismatched quotes, trailing commas, etc.)
    3. Support lenient parsing mode

    Returns: (success, parsed dictionary) -- on failure the dictionary carries
    the strict-parse and repair exceptions under "standard_error"/"repair_error".
    """
    content = text

    # Prefer an explicit ```json fence; fall back to a generic ``` fence.
    if '```json' in text:
        start = text.find('```json')
        # len('```json') == 7 -- the previous offset of 6 left a stray 'n'
        # at the front of `content`, masked only by the prefix regex below.
        end = text.find('```', start + 7)
        if end == -1:
            # Unterminated fence: take everything to the end instead of
            # silently dropping the last character via a -1 slice bound.
            end = len(text)
        content = text[start + 7:end].strip()
    elif '```' in text:
        start = text.find('```')
        end = text.find('```', start + 3)
        if end == -1:
            end = len(text)
        content = text[start + 3:end].strip()

    # Trim any chatter before the first '{'/'[' and after the last '}'/']'.
    content = re.sub(r'^[^{[]*', '', content)
    content = re.sub(r'[^}\]]*$', '', content)

    # First attempt: strict standard-library parsing.
    try:
        return True, json.loads(content)
    except json.JSONDecodeError as e:
        standard_error = e

    # Second attempt: lenient repair via json_repair.
    try:
        repaired = json_repair.repair_json(content)
        return True, json.loads(repaired)
    except Exception as repair_error:
        return False, {
            "standard_error": standard_error,
            "repair_error": repair_error
        }
| |
|
def path_join(dir, name):
    """Thin convenience wrapper around os.path.join."""
    joined = os.path.join(dir, name)
    return joined
| |
|
def dict_sort_print(dic:dict, value:bool=True, reverse=True):
    """Sort a dictionary and pretty-print it as JSON.

    Args:
        dic: Dictionary to display.
        value: Sort by values when True, by keys when False.
        reverse: Sort in descending order when True.
    """
    idx = 1 if value else 0
    # dict preserves insertion order, so constructing it from the sorted
    # pairs keeps the display order; this also avoids the old loop whose
    # `value` variable shadowed the `value` parameter.
    sorted_dic = dict(sorted(dic.items(), key=lambda item: item[idx], reverse=reverse))
    print(json.dumps(sorted_dic, indent=4, ensure_ascii=False))
| |
|
def clean_newlines(text: str) -> str:
    """
    Normalize lyric line breaks:
    1. Keep line breaks after punctuation
    2. Convert line breaks after non-punctuation -> space
    3. Fix extra spaces after English apostrophes
    4. Merge redundant spaces
    5. Preserve paragraph structure, ensure line breaks after punctuation
    """
    if not text:
        return ""

    # Unify Windows/old-Mac line endings, then flatten everything onto
    # a single line so breaks can be re-inserted deterministically.
    normalized = text.strip().replace('\r\n', '\n').replace('\r', '\n')
    flattened = ' '.join(
        piece.strip() for piece in normalized.split('\n') if piece.strip()
    )

    # Re-insert a newline after each punctuation mark.
    rebroken = re.sub(r'([.,!?:;οΌγοΌοΌοΌ])\s*', r'\1\n', flattened)

    # Glue back words split after an apostrophe ("don' t" -> "don't").
    rebroken = re.sub(r"'\s+", "'", rebroken)

    # Collapse runs of spaces/tabs (newlines untouched).
    rebroken = re.sub(r'[ \t]+', ' ', rebroken)

    # Strip stray whitespace on every resulting line.
    cleaned = '\n'.join(line.strip() for line in rebroken.split('\n'))
    return cleaned.strip()
| |
|
| | |
def is_ch_char(char:str):
    """Return True when the single character `char` is a Chinese character.

    Covers CJK Unified Ideographs (U+4E00-U+9FFF) and Extension A
    (U+3400-U+4DBF); anything else, including multi-char strings, is False.
    """
    if len(char) != 1:
        return False
    code = ord(char)
    return (0x4E00 <= code <= 0x9FFF) or (0x3400 <= code <= 0x4DBF)
| |
|
| | |
| |
|
def load_txt(path:str) -> str:
    """Load a file as plain text.

    Reads as UTF-8 explicitly: this module processes Chinese lyrics, so
    relying on the platform's locale default encoding was a latent bug.
    """
    with open(path, 'r', encoding='utf-8') as file:
        content = file.read()
    return content
| |
|
def load_json(path:str):
    """Load a JSON file.

    Returns an empty dict when the file does not exist. Reads as UTF-8
    explicitly instead of the platform locale default.
    """
    if not os.path.exists(path):
        return {}
    with open(path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data
| |
|
def load_jsonl(path:str, limit=-1) -> list[dict]:
    """Load a JSONL file.

    Args:
        path: JSONL file to read (one JSON object per line).
        limit: Stop after this many lines; -1 means read everything.
    """
    data = []
    # UTF-8 made explicit; `idx` replaces the old loop name `id`,
    # which shadowed the builtin.
    with open(path, 'r', encoding='utf-8') as file:
        for idx, line in tqdm(enumerate(file), desc=f"Loading {path}"):
            if limit != -1 and idx == limit:
                break
            data.append(json.loads(line))
    return data
| |
|
def save_json(data, path:str):
    """Serialize `data` to a pretty-printed UTF-8 JSON file at `path`."""
    serialized = json.dumps(data, ensure_ascii=False, indent=4)
    with open(path, 'w', encoding='utf-8') as file:
        file.write(serialized)
| |
|
def save_jsonl(data:list[dict], path:str, mode='w'):
    """Write records to a JSONL file, one JSON object per line.

    `mode` is passed straight to open(), so 'a' appends instead of
    overwriting.
    """
    with open(path, mode, encoding='utf-8') as file:
        for record in tqdm(data, desc=f"Saving to {path}"):
            file.write(json.dumps(record, ensure_ascii=False))
            file.write("\n")
| |
|
def audio_cut(input_path, mode:str, output_dir:str, segment_length:int=30000):
    """
    Extract a segment of specified length from an audio file
    - mode: Cut type (random / middle)
    - output_dir: Output folder (must already exist)
    - segment_length: Segment length (milliseconds)

    Returns the path of the exported 44.1kHz mono 16-bit WAV segment.
    Raises FileNotFoundError for a missing input and ValueError if the
    computed slice range is empty.
    """
    assert mode in ['random', 'middle']

    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Audio file not found: {input_path}")

    # Normalize to 44.1kHz mono before slicing; pydub lengths/slices are
    # in milliseconds.
    audio = AudioSegment.from_file(input_path)
    audio = audio.set_frame_rate(44100).set_channels(1)
    audio_duration = len(audio)

    # Too-short audio is used whole rather than padded.
    if audio_duration <= segment_length:
        print(f"Warning: Audio too short ({audio_duration}ms), using full audio: {input_path}")
        segment = audio
    else:
        if mode == "random":
            # Uniformly random start such that the window fits.
            max_start = max(0, audio_duration - segment_length)
            start = random.randint(0, max_start)
            end = start + segment_length
        else:
            # Window centered on the middle of the track.
            middle_point = audio_duration // 2
            start = max(0, middle_point - (segment_length // 2))
            end = min(audio_duration, start + segment_length)

        # Shift the window back inside the track if it ran past either edge.
        # NOTE(review): start is already >= 0 here, so the elif branch looks
        # unreachable -- kept defensively.
        if end > audio_duration:
            end = audio_duration
            start = end - segment_length
        elif start < 0:
            start = 0
            end = segment_length

        # Final clamp into [0, audio_duration] before slicing.
        start = max(0, min(start, audio_duration))
        end = max(0, min(end, audio_duration))

        if start >= end:
            raise ValueError(f"Invalid slice range: start={start}, end={end}, duration={audio_duration}")

        segment = audio[start:end]

    # Name the output after the source file: seg_<name>.wav
    basename = pure_name(input_path)
    output_path = os.path.join(output_dir, f"seg_{basename}.wav")

    # Export as uncompressed 16-bit PCM WAV.
    segment.export(
        output_path,
        format="wav",
        codec="pcm_s16le",
        parameters=["-acodec", "pcm_s16le"]
    )
    return output_path
| |
|
def format_meta(dir:str, show:bool=True) -> list[dict]:
    """Recursively get all audio paths (wav / mp3) in a folder and build JSONL.

    Args:
        dir: Root directory to scan; non-directories yield [].
        show: Show a progress bar for this level (recursive calls pass False).

    Returns:
        A list of {"path": <audio file path>} records.
    """
    if not os.path.isdir(dir):
        return []
    names = os.listdir(dir)
    # The original duplicated the whole loop just to toggle the progress
    # bar; wrapping the iterable conditionally keeps one copy of the logic.
    iterator = tqdm(names, desc=f"Formating {dir}") if show else names
    dataset = []
    for name in iterator:
        path = os.path.join(dir, name)
        if os.path.isdir(path):
            # Recurse silently into subdirectories.
            dataset += format_meta(path, False)
        elif name.endswith(('.mp3', '.wav')):
            dataset.append({"path": path})
    return dataset
| |
|
def dup_remove(raw_data:list[dict], save_path:str, key:str, seg:str):
    """
    Remove already generated items from dataset
    - key is the primary key in raw dataset, foreign key in save
    - seg is the target field
    """
    # No save file yet means nothing has been generated.
    if not os.path.exists(save_path):
        print(f"Dup num: 0")
        return raw_data
    # Collect the keys of records whose target field is already filled.
    done_keys = set()
    for record in tqdm(load_jsonl(save_path), desc="Constructing Dup Set"):
        if seg in record:
            done_keys.add(record[key])
    rest_data = []
    for item in tqdm(raw_data, desc="Checking Dup"):
        if item[key] not in done_keys:
            rest_data.append(item)
    print(f"Dup num: {len(raw_data) - len(rest_data)}")
    return rest_data
| |
|
def tar_size_check(data_dir:str, subfixes:list[str], per:int, max_size:int):
    """
    Determine the number of files that can fit in a block before compression
    (assuming uniform file sizes)
    - data_dir: Folder to compress
    - subfixes: File suffixes to compress (e.g., .mp3)
    - per: Check every N files on average
    - max_size: Maximum limit in GB
    """
    count = 0
    size_sum = 0
    for name in tqdm(sorted(os.listdir(data_dir)), desc="Size Checking"):
        # Skip files whose extension is not in the compression list.
        if os.path.splitext(name)[1] not in subfixes:
            continue
        count += 1
        size_sum += os.path.getsize(os.path.join(data_dir, name))
        # Only report every `per` matching files.
        if count % per:
            continue
        gb_size = size_sum / 1024 / 1024 / 1024
        if gb_size > max_size:
            break
        print(f"Count: {count}, Size: {gb_size:.2f}GB")
| |
|
def tar_dir(
    data_dir:str,
    subfixes:list[str],
    save_dir:str,
    group_size:int,
    tmp_dir:str,
    mark:str,
    max_workers:int=10,
):
    """Compress files in a directory in chunks (non-recursive).

    Args:
        data_dir: Directory whose files are archived.
        subfixes: File suffixes to include (e.g. ".mp3").
        save_dir: Destination for the block_<i>_<mark>.tar.gz archives.
        group_size: Number of directory entries per block.
        tmp_dir: Scratch directory for the tar --files-from name lists.
        mark: Tag embedded in list/archive file names.
        max_workers: Thread count handed to pigz for parallel compression.

    Requires the external `tar` and `pigz` binaries on PATH.
    """
    names = sorted(list(os.listdir(data_dir)))
    file_num = len(names)
    for i in range(0, file_num, group_size):
        names_subset = names[i:i+group_size]
        size_sum = 0
        # Write this block's relative paths to a list file that tar
        # will consume via --files-from.
        name_path = os.path.join(tmp_dir, f"name_{i}_{mark}")
        with open(name_path, 'w', encoding='utf-8') as file:
            for name in tqdm(names_subset, desc=f"Counting Block {i}"):
                path = os.path.join(data_dir, name)
                subfix = os.path.splitext(path)[1]
                if subfix not in subfixes:
                    continue
                file.write("./" + name + "\n")
                size_sum += os.path.getsize(path)
        gb_size = size_sum / 1024 / 1024 / 1024
        print(f"Zipping block {i+1}, size: {gb_size:.2f}GB")

        # tar streams the listed files (uncompressed) to stdout; pigz
        # compresses that stream in parallel with `max_workers` threads.
        tar_cmd = [
            'tar',
            '--no-recursion',
            '--files-from', str(name_path),
            '-cf', '-'
        ]
        pigz_cmd = ['pigz', '-p', str(max_workers), '-c']

        tar_process = subprocess.Popen(tar_cmd, stdout=subprocess.PIPE, cwd=data_dir)
        pigz_process = subprocess.Popen(pigz_cmd, stdin=tar_process.stdout, stdout=subprocess.PIPE, cwd=data_dir)

        # Copy the compressed stream to the archive in 4 KB chunks.
        save_path = os.path.join(save_dir, f"block_{i}_{mark}.tar.gz")
        with open(save_path, 'wb') as out_file:
            while True:
                data = pigz_process.stdout.read(4096)
                if not data:
                    break
                out_file.write(data)

        tar_process.wait()
        pigz_process.wait()

        # Report per-block success/failure; a non-zero code from either
        # process means the archive may be incomplete.
        if tar_process.returncode == 0 and pigz_process.returncode == 0:
            print(f"Compression completed: {save_path}")
        else:
            print(f"Compression failed: tar return code={tar_process.returncode}, pigz return code={pigz_process.returncode}")
| |
|
def music_avg_size(dir:str):
    """Average music size (MB), length (s).

    Samples at most the first 50 audio files found under `dir`.
    Returns: (average size in MB, average length in seconds).
    """
    dataset = format_meta(dir)
    # Cap the sample at 50 tracks to keep the scan fast.
    dataset = dataset[:50]
    size_sum = 0
    length_sum = 0
    for ele in tqdm(dataset, desc=f"Counting Music Size in {dir}"):
        path = ele['path']
        audio = AudioSegment.from_file(path)
        # pydub lengths are in milliseconds.
        length_sum += len(audio) / 1000.0
        size_sum += os.path.getsize(path)
    # NOTE(review): divides by len(dataset) -- raises ZeroDivisionError when
    # the directory holds no audio files; confirm callers guarantee input.
    size_avg = size_sum / len(dataset) / 1024 / 1024
    length_avg = length_sum / len(dataset)
    return size_avg, length_avg
| |
|
def get_sample(path:str, save_path:str="tmp.jsonl", num:int=100):
    """Sample up to `num` records from a JSON/JSONL file into a JSONL file.

    Silently returns for missing paths; prints a notice for unsupported
    extensions.
    """
    if not os.path.exists(path):
        return
    if path.endswith(".jsonl"):
        dataset = load_jsonl(path)
    elif path.endswith(".json"):
        dataset = load_json(path)
    else:
        print(f"Unsupport file: {path}")
        return
    # random.sample raises ValueError when the population is smaller than
    # the request; cap it so small files still produce a sample.
    sub_dataset = random.sample(dataset, min(num, len(dataset)))
    save_jsonl(sub_dataset, save_path)
| |
|
| | def _get_field_one(path:str, field:str): |
| | """Process data from one path""" |
| | with open(path, 'r') as file: |
| | data = json.load(file) |
| | new_data = { |
| | "id": f"{data['song_id']}_{data['track_index']}", |
| | field: data[field] |
| | } |
| | return new_data |
| |
|
def get_field_suno(dir:str, save_path:str, field:str, max_workers:int=8):
    """Extract a specific field from scattered JSON files in suno.

    Fans the per-file reads out over a process pool and streams each
    result to `save_path` as JSONL in completion order.
    """
    paths = [
        os.path.join(dir, name)
        for name in tqdm(os.listdir(dir), desc="Getting names")
        if name.endswith(".json")
    ]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_get_field_one, path, field) for path in paths]
        with open(save_path, 'w', encoding='utf-8') as file, \
             tqdm(total=len(paths), desc="Processing the JSONs") as pbar:
            for future in as_completed(futures):
                file.write(json.dumps(future.result(), ensure_ascii=False))
                file.write("\n")
                pbar.update(1)
| |
|
def find_json(dir:str) -> list[str]:
    """Return the names of JSON / JSONL files directly inside a folder."""
    return [
        name
        for name in tqdm(os.listdir(dir), desc="Finding JSON/JSONL")
        if name.endswith((".json", ".jsonl"))
    ]
| |
|
def show_dir(dir:str):
    """Print every entry name directly inside `dir`; no-op if not a dir."""
    if not os.path.isdir(dir):
        return
    for entry in os.listdir(dir):
        print(entry)
| |
|
def _convert_mp3(path:str, dir:str):
    """Convert one audio file to mp3 inside `dir`.

    Returns "pass" when the target already exists, "fail" when the source
    cannot be decoded, "finish" on a successful conversion.
    """
    output_path = os.path.join(dir, pure_name(path) + ".mp3")
    # Skip work a previous run already completed.
    if os.path.exists(output_path):
        return "pass"
    try:
        audio = AudioSegment.from_file(path)
    except Exception:
        # Best-effort batch: report the unreadable source, keep going.
        print(f"fail to load {path}")
        return "fail"
    audio.export(output_path, format='mp3')
    return "finish"
| |
|
def convert_mp3(meta_path:str, dir:str, max_workers:int=10):
    """Convert all specified audio files to mp3 and save in specified directory.

    Args:
        meta_path: JSONL meta file whose records carry a 'path' field.
        dir: Output directory (created if missing).
        max_workers: Thread pool size (conversion is I/O heavy).
    """
    os.makedirs(dir, exist_ok=True)
    dataset = load_jsonl(meta_path)
    pass_num = 0
    finish_num = 0
    fail_num = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_convert_mp3, ele['path'], dir) for ele in dataset]
        with tqdm(total=len(dataset), desc=f"Converting {meta_path}") as pbar:
            for future in as_completed(futures):
                res = future.result()
                if res == "pass":
                    pass_num += 1
                elif res == "fail":
                    # Previously lumped together with finished conversions,
                    # which hid decode failures from the final summary.
                    fail_num += 1
                else:
                    finish_num += 1
                pbar.update(1)
    print(f"Finish {finish_num}, Pass {pass_num}, Fail {fail_num}")
| |
|
| | |
| |
|
def get_free_gpu() -> int:
    """Return the ID of the GPU with the most free memory (via nvidia-smi)."""
    cmd = "nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits"
    lines = subprocess.check_output(cmd.split()).decode().strip().split("\n")

    # Parse "index, free_mem" rows and pick the one with maximal free memory
    # (ties resolve to the first-listed GPU, as with a stable sort).
    pairs = [
        (int(idx), int(free_mem))
        for idx, free_mem in (line.split(",") for line in lines)
    ]
    best_idx, _best_mem = max(pairs, key=lambda pair: pair[1])
    return best_idx
| |
|
| | |
| |
|
def compose_analyze(dataset:list[dict]):
    """Statistical analysis of music structure composition.

    Prints per-label counts, then counts of whole label sequences
    (song structures) joined with " | ".
    """
    # Count individual section labels across all songs.
    label_count = defaultdict(int)
    for ele in tqdm(dataset):
        for segment in ele['segments']:
            label_count[segment['label']] += 1
    print(f"Number of labels: {len(label_count)}")
    # dict_sort_print prints the table itself and returns None; the old
    # print(dict_sort_print(...)) emitted a spurious "None" line.
    dict_sort_print(label_count)

    # Count whole label sequences; songs with no segments are skipped.
    label_combs = defaultdict(int)
    for ele in tqdm(dataset):
        labels = [segment['label'] for segment in ele['segments']]
        if len(labels) == 0:
            continue
        label_combs[" | ".join(labels)] += 1
    print(f"Number of combinations: {len(label_combs)}")
    dict_sort_print(label_combs)
| |
|
| | def _filter_tag(content:str) -> list[str]: |
| | """Split and format tag fields""" |
| | tags = [] |
| | raws = re.split(r'[,οΌ.]', content) |
| | for raw in raws: |
| | raw = raw.strip().lower() |
| | if raw == "": |
| | continue |
| | seg_pos = raw.find(":") |
| | if seg_pos != -1: |
| | |
| | tag = raw[seg_pos+1:].strip() |
| | else: |
| | tag = raw |
| | tags.append(tag) |
| | return tags |
| |
|
def tags_analyze(dataset:list[dict]):
    """Song tag analysis: frequency of each normalized style tag."""
    tag_count = defaultdict(int)
    for ele in tqdm(dataset, desc="Tag analyzing"):
        for tag in _filter_tag(ele['style']):
            tag_count[tag] += 1
    print(f"Number of tags: {len(tag_count.keys())}")
    # dict_sort_print prints the table itself and returns None; the old
    # print(dict_sort_print(...)) emitted a spurious "None" line.
    dict_sort_print(tag_count)