# Jacong's picture
# Upload 96 files
# aa9be1e verified
import os
import re
import json
import random
import tarfile
import subprocess
import json_repair
from tqdm import tqdm
from pathlib import Path
from pydub import AudioSegment
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
# ===== Macros =====
# Directory containing this script; anchor for building relative paths.
BASE_DIR = Path(__file__).parent
# ===== Helper Functions =====
def pure_name(path: str):
    """Get the original name of a file path (without extension)."""
    name = os.path.basename(path)
    # rpartition splits on the LAST dot; when no dot exists the middle
    # element is empty and we keep the whole name unchanged.
    stem, dot, _ext = name.rpartition('.')
    return stem if dot else name
def extract_json(text: str) -> tuple[bool, dict]:
    """Extract and repair JSON data from text (enhanced error-tolerant version)

    Features:
    1. Automatically identify code block markers (```json``` or ```)
    2. Fix common JSON errors (mismatched quotes, trailing commas, etc.)
    3. Support lenient parsing mode

    Returns: (success, parsed data); on failure the dict carries both the
    standard parse error and the repair error.
    """
    # Preprocessing: extract possible JSON content area
    content = text
    # Case 1: ```json``` code block.
    # BUG FIX: the opening marker '```json' is 7 characters long; the old
    # code skipped only 6, leaving a stray 'n' in the content.
    if '```json' in text:
        start = text.find('```json') + 7
        end = text.find('```', start)
        # BUG FIX: when the closing fence is missing, find() returns -1 and
        # slicing with it silently dropped the last character of the JSON.
        content = (text[start:end] if end != -1 else text[start:]).strip()
    # Case 2: regular ``` code block
    elif '```' in text:
        start = text.find('```') + 3
        end = text.find('```', start)
        content = (text[start:end] if end != -1 else text[start:]).strip()
    # Clean common interference items in content
    content = re.sub(r'^[^{[]*', '', content)   # Remove unstructured content before JSON
    content = re.sub(r'[^}\]]*$', '', content)  # Remove unstructured content after JSON
    # Try standard parsing first
    try:
        return True, json.loads(content)
    except json.JSONDecodeError as e:
        standard_error = e
    # Fall back to lenient repair with json_repair
    try:
        repaired = json_repair.repair_json(content)
        return True, json.loads(repaired)
    except Exception as e:
        return False, {
            "standard_error": standard_error,
            "repair_error": e,
        }
def path_join(dir, name):
    """Thin convenience wrapper around os.path.join(dir, name)."""
    return os.path.join(dir, name)
def dict_sort_print(dic:dict, value:bool=True, reverse=True):
"""Sort a dictionary by value size and output"""
idx = 1 if value else 0
sorted_lis = sorted(dic.items(), key=lambda x: x[idx], reverse=reverse)
sorted_dic = {}
for key, value in sorted_lis:
sorted_dic[key] = value
print(json.dumps(sorted_dic, indent=4, ensure_ascii=False))
def clean_newlines(text: str) -> str:
    """
    Clean lyric line breaks:
    1. Keep line breaks after punctuation
    2. Convert line breaks after non-punctuation into spaces
    3. Fix extra spaces after English apostrophes
    4. Merge redundant spaces
    5. Preserve paragraph structure, ensure line breaks after punctuation
    """
    if not text:
        return ""
    # Normalize every newline style to a single '\n'
    normalized = text.strip().replace('\r\n', '\n').replace('\r', '\n')
    # Collapse the original line structure into one long sentence
    merged = ' '.join(part.strip() for part in normalized.split('\n') if part.strip())
    # Re-insert a newline after each sentence-ending punctuation mark
    # (both ASCII and CJK), swallowing any whitespace that followed it
    merged = re.sub(r'([.,!?:;，。！？；])\s*', r'\1\n', merged)
    # Remove the space a contraction may have picked up after its apostrophe
    merged = re.sub(r"'\s+", "'", merged)
    # Squash runs of spaces/tabs into a single space
    merged = re.sub(r'[ \t]+', ' ', merged)
    # Trim each line, then the whole text
    return '\n'.join(ln.strip() for ln in merged.split('\n')).strip()
# ===== Detection Functions =====
def is_ch_char(char: str):
    """Determine if a single character is a Chinese character.

    Recognized Unicode ranges:
    - CJK Unified Ideographs: 0x4E00-0x9FFF (covers most cases)
    - Extension A:            0x3400-0x4DBF
    Extensions B-E (0x20000+) are deliberately not considered for now.
    """
    if len(char) != 1:
        return False
    code = ord(char)
    return 0x4E00 <= code <= 0x9FFF or 0x3400 <= code <= 0x4DBF
# ===== File Operations =====
def load_txt(path: str) -> str:
    """Load a file as plain text.

    FIX: reads as UTF-8 explicitly — the save_* helpers in this module all
    write UTF-8, while the old code relied on the platform default encoding
    (which breaks on Windows / non-UTF-8 locales).
    """
    with open(path, 'r', encoding='utf-8') as file:
        return file.read()
def load_json(path: str):
    """Load a JSON file; return {} when the file does not exist.

    FIX: reads as UTF-8 explicitly, matching save_json's output encoding,
    instead of relying on the platform default.
    """
    if not os.path.exists(path):
        return {}
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
def load_jsonl(path: str, limit=-1) -> list[dict]:
    """Load a JSONL file.

    - limit: stop after this many records (-1 means read everything)

    FIXES: the loop variable `id` shadowed the builtin; the file is now
    read as UTF-8 explicitly, matching save_jsonl's output encoding.
    """
    data = []
    with open(path, 'r', encoding='utf-8') as file:
        for idx, line in tqdm(enumerate(file), desc=f"Loading {path}"):
            # Stop early once the requested number of records is loaded
            if limit != -1 and idx == limit:
                break
            data.append(json.loads(line))
    return data
def save_json(data, path: str):
    """Serialize *data* to *path* as pretty-printed UTF-8 JSON."""
    serialized = json.dumps(data, ensure_ascii=False, indent=4)
    with open(path, 'w', encoding='utf-8') as file:
        file.write(serialized)
def save_jsonl(data: list[dict], path: str, mode='w'):
    """Write *data* to *path* as UTF-8 JSONL, one record per line.

    - mode: passed straight to open(), so 'a' appends instead of overwriting
    """
    with open(path, mode, encoding='utf-8') as file:
        for record in tqdm(data, desc=f"Saving to {path}"):
            file.write(json.dumps(record, ensure_ascii=False))
            file.write("\n")
def audio_cut(input_path, mode: str, output_dir: str, segment_length: int = 30000):
    """
    Extract a segment of specified length from an audio file.
    - mode: Cut type (random / middle)
    - output_dir: Output folder
    - segment_length: Segment length (milliseconds)
    Returns the path of the exported .wav segment.
    Raises FileNotFoundError when the input is missing, ValueError on an
    invalid slice range.
    """
    assert mode in ['random', 'middle']
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Audio file not found: {input_path}")
    # Load and normalize: 44.1 kHz, mono
    audio = AudioSegment.from_file(input_path)
    audio = audio.set_frame_rate(44100).set_channels(1)
    audio_duration = len(audio)  # pydub reports length in milliseconds
    if audio_duration <= segment_length:
        # Audio shorter than the requested segment: keep it whole
        print(f"Warning: Audio too short ({audio_duration}ms), using full audio: {input_path}")
        segment = audio
    else:
        # audio_duration > segment_length here, so a full-length slice
        # always fits and the old post-hoc boundary fixups were dead code.
        if mode == "random":
            # Any start that still leaves a full segment
            start = random.randint(0, audio_duration - segment_length)
        else:
            # Centered slice; max() keeps start non-negative
            start = max(0, audio_duration // 2 - segment_length // 2)
        end = start + segment_length
        # Defensive clamp retained from the original implementation
        start = max(0, min(start, audio_duration))
        end = max(0, min(end, audio_duration))
        if start >= end:
            raise ValueError(f"Invalid slice range: start={start}, end={end}, duration={audio_duration}")
        segment = audio[start:end]
    # Output name mirrors the source file: seg_<name>.wav
    basename = pure_name(input_path)
    output_path = os.path.join(output_dir, f"seg_{basename}.wav")
    segment.export(
        output_path,
        format="wav",
        codec="pcm_s16le",                    # 16-bit little-endian PCM
        parameters=["-acodec", "pcm_s16le"]   # explicit ffmpeg codec flag
    )
    return output_path
def format_meta(dir: str, show: bool = True) -> list[dict]:
    """Recursively get all audio paths (wav / mp3) in a folder and build JSONL.

    - show: wrap the top-level scan in a tqdm progress bar; recursive calls
      always pass False so only the top level reports progress.
    Returns a list of {"path": <absolute-or-relative path>} records; an
    empty list when *dir* is not a directory.

    FIX: the show/no-show branches duplicated the whole loop body; they now
    share one loop and only the iterator differs.
    """
    if not os.path.isdir(dir):
        return []
    names = os.listdir(dir)
    iterator = tqdm(names, desc=f"Formating {dir}") if show else names
    dataset = []
    for name in iterator:
        path = os.path.join(dir, name)
        if os.path.isdir(path):
            # Recurse without a nested progress bar
            dataset += format_meta(path, False)
        elif name.endswith(('.mp3', '.wav')):
            dataset.append({"path": path})
    return dataset
def dup_remove(raw_data: list[dict], save_path: str, key: str, seg: str):
    """
    Remove already generated items from dataset.
    - key is the primary key in raw dataset, foreign key in save
    - seg is the target field
    Returns the records of raw_data whose key has not yet been processed.
    """
    if not os.path.exists(save_path):
        print(f"Dup num: 0")
        return raw_data
    # Collect keys of saved records that already carry the target field
    done_keys = set()
    for record in tqdm(load_jsonl(save_path), desc="Constructing Dup Set"):
        if seg in record:
            done_keys.add(record[key])
    remaining = []
    duplicates = 0
    for record in tqdm(raw_data, desc="Checking Dup"):
        if record[key] in done_keys:
            duplicates += 1
        else:
            remaining.append(record)
    print(f"Dup num: {duplicates}")
    return remaining
def tar_size_check(data_dir: str, subfixes: list[str], per: int, max_size: int):
    """
    Determine the number of files that can fit in a block before compression
    (assuming uniform file sizes).
    - data_dir: Folder to compress
    - subfixes: File suffixes to compress (e.g., .mp3)
    - per: Check every N files on average
    - max_size: Maximum limit in GB

    FIX: `gb_size` was only assigned inside the `count % per == 0` checkpoint,
    so the final print raised NameError whenever fewer than `per` files
    matched; the size is now (re)computed before printing.
    """
    names = sorted(os.listdir(data_dir))
    count = 0
    size_sum = 0
    for name in tqdm(names, desc="Size Checking"):
        path = os.path.join(data_dir, name)
        if os.path.splitext(name)[1] not in subfixes:
            continue
        count += 1
        size_sum += os.path.getsize(path)
        # Periodic checkpoint: stop once the accumulated size exceeds the cap
        if count % per == 0 and size_sum / 1024 / 1024 / 1024 > max_size:
            break
    # Compute the final size once so the summary is always accurate
    gb_size = size_sum / 1024 / 1024 / 1024
    print(f"Count: {count}, Size: {gb_size:.2f}GB")
def tar_dir(
    data_dir:str,
    subfixes:list[str],
    save_dir:str,
    group_size:int,
    tmp_dir:str,
    mark:str,
    max_workers:int=10,
):
    """Compress files in a directory in chunks (non-recursive).

    Splits the directory listing into groups of `group_size` entries and pipes
    `tar` through `pigz` to produce one `block_<i>_<mark>.tar.gz` per group.

    - data_dir: Directory whose files are archived (non-recursive)
    - subfixes: Only files with these extensions (e.g. ".mp3") are included
    - save_dir: Where the .tar.gz blocks are written
    - group_size: Number of directory entries per block
    - tmp_dir: Where the temporary file-list manifests are written
    - mark: Tag appended to manifest and archive names
    - max_workers: Number of pigz compression threads

    NOTE(review): requires `tar` and `pigz` on PATH; manifest files in
    tmp_dir are not cleaned up afterwards.
    """
    names = sorted(list(os.listdir(data_dir)))
    file_num = len(names)
    for i in range(0, file_num, group_size):
        names_subset = names[i:i+group_size]
        size_sum = 0
        # Manifest listing this block's members, consumed by tar --files-from
        name_path = os.path.join(tmp_dir, f"name_{i}_{mark}")
        with open(name_path, 'w', encoding='utf-8') as file:
            for name in tqdm(names_subset, desc=f"Counting Block {i}"):
                path = os.path.join(data_dir, name)
                subfix = os.path.splitext(path)[1]
                if subfix not in subfixes:
                    continue
                # Relative "./name" paths because tar runs with cwd=data_dir
                file.write("./" + name + "\n")
                size_sum += os.path.getsize(path)
        gb_size = size_sum / 1024 / 1024 / 1024
        print(f"Zipping block {i+1}, size: {gb_size:.2f}GB")
        # tar writes an uncompressed archive to stdout ('-cf -') ...
        tar_cmd = [
            'tar',
            '--no-recursion',
            '--files-from', str(name_path),
            '-cf', '-'
        ]
        # ... which pigz compresses in parallel on max_workers threads
        pigz_cmd = ['pigz', '-p', str(max_workers), '-c']
        tar_process = subprocess.Popen(tar_cmd, stdout=subprocess.PIPE, cwd=data_dir)
        pigz_process = subprocess.Popen(pigz_cmd, stdin=tar_process.stdout, stdout=subprocess.PIPE, cwd=data_dir)
        save_path = os.path.join(save_dir, f"block_{i}_{mark}.tar.gz")
        # Stream the compressed output to disk in 4 KiB chunks
        with open(save_path, 'wb') as out_file:
            while True:
                data = pigz_process.stdout.read(4096)
                if not data:
                    break
                out_file.write(data)
        # Wait for both ends of the pipeline before checking exit codes
        tar_process.wait()
        pigz_process.wait()
        if tar_process.returncode == 0 and pigz_process.returncode == 0:
            print(f"Compression completed: {save_path}")
        else:
            print(f"Compression failed: tar return code={tar_process.returncode}, pigz return code={pigz_process.returncode}")
def music_avg_size(dir: str):
    """Average music size (MB) and length (s), sampled over at most the
    first 50 files found under *dir*.

    Returns (avg_size_mb, avg_length_s); (0.0, 0.0) when no audio is found.

    FIX: previously divided by zero when the folder contained no audio.
    """
    dataset = format_meta(dir)[:50]  # sample cap keeps the scan fast
    if not dataset:
        return 0.0, 0.0
    size_sum = 0
    length_sum = 0.0
    for ele in tqdm(dataset, desc=f"Counting Music Size in {dir}"):
        path = ele['path']
        audio = AudioSegment.from_file(path)
        length_sum += len(audio) / 1000.0  # pydub length is milliseconds
        size_sum += os.path.getsize(path)
    size_avg = size_sum / len(dataset) / 1024 / 1024
    length_avg = length_sum / len(dataset)
    return size_avg, length_avg
def get_sample(path: str, save_path: str = "tmp.jsonl", num: int = 100):
    """Get up to N random records from a JSONL (or JSON list) file and write
    them to *save_path* as JSONL. Silently returns when *path* is missing.

    FIXES: random.sample raised ValueError when the dataset held fewer than
    `num` records — the sample size is now clamped; also fixed the typo in
    the unsupported-file message.
    """
    if not os.path.exists(path):
        return
    if path.endswith(".jsonl"):
        dataset = load_jsonl(path)
    elif path.endswith(".json"):
        dataset = load_json(path)
    else:
        print(f"Unsupported file: {path}")
        return
    # Clamp so sampling never exceeds the population size
    sub_dataset = random.sample(dataset, min(num, len(dataset)))
    save_jsonl(sub_dataset, save_path)
def _get_field_one(path:str, field:str):
"""Process data from one path"""
with open(path, 'r') as file:
data = json.load(file)
new_data = {
"id": f"{data['song_id']}_{data['track_index']}",
field: data[field]
}
return new_data
def get_field_suno(dir: str, save_path: str, field: str, max_workers: int = 8):
    """Extract a specific field from scattered JSON files in suno,
    processing files in a process pool and streaming results to a JSONL."""
    # Gather every .json file in the directory
    paths = [
        os.path.join(dir, name)
        for name in tqdm(os.listdir(dir), desc="Getting names")
        if name.endswith(".json")
    ]
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_get_field_one, path, field) for path in paths]
        with open(save_path, 'w', encoding='utf-8') as file, \
                tqdm(total=len(paths), desc="Processing the JSONs") as pbar:
            # Write results in completion order, one JSON object per line
            for future in as_completed(futures):
                json.dump(future.result(), file, ensure_ascii=False)
                file.write("\n")
                pbar.update(1)
def find_json(dir: str) -> list[str]:
    """Find JSONL / JSON file names (not paths) in a folder."""
    return [
        name
        for name in tqdm(os.listdir(dir), desc="Finding JSON/JSONL")
        if name.endswith((".json", ".jsonl"))
    ]
def show_dir(dir: str):
    """Print every entry name in *dir*; do nothing for a non-directory."""
    if os.path.isdir(dir):
        for entry in os.listdir(dir):
            print(entry)
def _convert_mp3(path: str, dir: str):
    """Convert one audio file to mp3 inside *dir*.

    Returns "pass" when the target already exists, "fail" when the source
    cannot be decoded, "finish" on success.
    """
    output_path = os.path.join(dir, pure_name(path) + ".mp3")
    # Skip work that a previous run already completed
    if os.path.exists(output_path):
        return "pass"
    try:
        audio = AudioSegment.from_file(path)
    except Exception:
        # Best-effort: report the unreadable file and move on
        print(f"fail to load {path}")
        return "fail"
    audio.export(output_path, format='mp3')
    return "finish"
def convert_mp3(meta_path: str, dir: str, max_workers: int = 10):
    """Convert all audio files listed in the JSONL at *meta_path* to mp3,
    saving them into *dir* (created if missing).

    FIX: _convert_mp3 can return "pass" / "fail" / "finish", but failures
    were previously lumped into finish_num and reported as finished; they
    are now counted and printed separately.
    """
    os.makedirs(dir, exist_ok=True)
    dataset = load_jsonl(meta_path)
    pass_num = 0
    finish_num = 0
    fail_num = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_convert_mp3, ele['path'], dir) for ele in dataset]
        with tqdm(total=len(dataset), desc=f"Converting {meta_path}") as pbar:
            for future in as_completed(futures):
                res = future.result()
                if res == "pass":
                    pass_num += 1
                elif res == "fail":
                    fail_num += 1
                else:
                    finish_num += 1
                pbar.update(1)
    print(f"Finish {finish_num}, Fail {fail_num}, Pass {pass_num}")
# ===== GPU and Models =====
def get_free_gpu() -> int:
    """Return the GPU ID with the most free memory (least usage)."""
    cmd = "nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits"
    output = subprocess.check_output(cmd.split()).decode().strip()
    gpus = []
    for line in output.split("\n"):
        idx, free_mem = line.split(",")
        gpus.append((int(idx), int(free_mem)))  # (GPU id, free memory MiB)
    # max() returns the first maximum, matching the stable reverse sort of
    # the original implementation on ties
    best_id, _best_free = max(gpus, key=lambda pair: pair[1])
    return best_id
# ===== Data Analysis =====
def compose_analyze(dataset: list[dict]):
    """Statistical analysis of music structure composition: counts segment
    labels and distinct label combinations, printing both summaries.

    FIX: dict_sort_print prints and returns None, so the old
    print(dict_sort_print(...)) wrappers emitted a stray "None" line after
    each table; it is now called directly.
    """
    # Label count
    label_count = defaultdict(int)
    for ele in tqdm(dataset):
        for segment in ele['segments']:
            label_count[segment['label']] += 1
    print(f"Number of labels: {len(label_count)}")
    dict_sort_print(label_count)
    # Different combinations (ordered label sequence per song)
    label_combs = defaultdict(int)
    for ele in tqdm(dataset):
        labels = [segment['label'] for segment in ele['segments']]
        if len(labels) == 0:
            continue
        label_combs[" | ".join(labels)] += 1
    print(f"Number of combinations: {len(label_combs)}")
    dict_sort_print(label_combs)
def _filter_tag(content:str) -> list[str]:
"""Split and format tag fields"""
tags = []
raws = re.split(r'[,,.]', content)
for raw in raws:
raw = raw.strip().lower() # Remove spaces and convert to lowercase
if raw == "":
continue
seg_pos = raw.find(":")
if seg_pos != -1:
# If colon exists, only take the part after it
tag = raw[seg_pos+1:].strip()
else:
tag = raw
tags.append(tag)
return tags
def tags_analyze(dataset: list[dict]):
    """Song tag analysis: counts normalized tags from each record's 'style'
    field and prints a sorted frequency table.

    FIX: dict_sort_print prints and returns None, so the old
    print(dict_sort_print(...)) wrapper emitted a stray "None" line; it is
    now called directly.
    """
    tag_count = defaultdict(int)
    for ele in tqdm(dataset, desc="Tag analyzing"):
        for tag in _filter_tag(ele['style']):
            tag_count[tag] += 1
    print(f"Number of tags: {len(tag_count.keys())}")
    dict_sort_print(tag_count)