Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import argparse | |
| import torch | |
| import laion_clap | |
| import numpy as np | |
| import multiprocessing | |
| from tqdm import tqdm | |
def parse_args():
    """Parse command-line options for the CLAP-labelling run.

    Returns:
        argparse.Namespace with `num_samples` (int, default 5),
        `json_path` (str, required) and `output_dir` (str, required).
    """
    parser = argparse.ArgumentParser(
        description="Labelling clap score for crpo dataset"
    )
    parser.add_argument("--num_samples", type=int, default=5,
                        help="Number of audio samples per prompt")
    parser.add_argument("--json_path", type=str, required=True,
                        help="Path to input JSON file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the final JSON with CLAP scores")
    return parser.parse_args()
# Example usage:
# python3 label_clap.py --json_path=/mnt/data/chiayu/crpo/crpo_iteration1/results.json --output_dir=/mnt/data/chiayu/crpo/crpo_iteration1
def compute_clap(model, audio_files, text_data):
    """Return the CLAP similarity matrix between audio files and texts.

    Embeds `audio_files` and `text_data` with the given CLAP model and
    returns their inner products: a tensor of shape
    (len(audio_files), len(text_data)).
    """
    a_emb = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=True)
    t_emb = model.get_text_embedding(text_data, use_tensor=True)
    return torch.matmul(a_emb, t_emb.T)
def process_chunk(args, chunk, gpu_id, return_dict, process_id):
    """
    Score one chunk of the dataset on a single GPU.

    Loads a CLAP model onto cuda:`gpu_id`, attaches a `clap_score` to every
    sample of every item in `chunk`, and publishes the annotated chunk under
    `return_dict[process_id]`. On a fatal error the worker publishes an
    empty list instead so the parent never blocks on a missing key.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # One CLAP model per worker, pinned to this worker's GPU.
        model = laion_clap.CLAP_Module(enable_fusion=False)
        model.to(device)
        model.load_ckpt()
        model.eval()

        for idx, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
            # `item` is a list of generated samples for one prompt.
            # Items that already carry a score are skipped (resume support).
            if 'clap_score' in item[0]:
                continue

            # All samples share the prompt, so one caption suffices.
            audio_files = [item[s]['path'] for s in range(args.num_samples)]
            text_data = [item[0]['captions']]
            try:
                clap_scores = compute_clap(model, audio_files, text_data)
            except Exception as e:
                print(f"Error processing item index {idx} on GPU {gpu_id}: {e}")
                continue

            # clap_scores is (num_samples, 1); round for compact JSON output.
            for s in range(args.num_samples):
                item[s]['clap_score'] = np.round(clap_scores[s].item(), 3)

        return_dict[process_id] = chunk
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        return_dict[process_id] = []
def split_into_chunks(data, num_chunks):
    """Split `data` into `num_chunks` contiguous, balanced parts.

    Chunk sizes differ by at most one item (the first `len(data) % num_chunks`
    chunks get the extra element), and concatenating the chunks in order
    reproduces `data` exactly.

    The previous implementation gave every chunk `len(data) // num_chunks`
    items and dumped the entire remainder onto the LAST chunk — up to
    `num_chunks - 1` extra items on one GPU, and, when
    `len(data) < num_chunks`, ALL items on the last GPU with the rest idle.

    Args:
        data: sequence to split.
        num_chunks: number of parts; must be > 0 (raises ZeroDivisionError
            for 0, matching the old behavior).

    Returns:
        List of `num_chunks` lists.
    """
    base, remainder = divmod(len(data), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # First `remainder` chunks absorb one extra item each.
        end = start + base + (1 if i < remainder else 0)
        chunks.append(data[start:end])
        start = end
    return chunks
def main():
    """Fan the dataset out across all GPUs, gather CLAP scores, and build
    the chosen/reject preference pairs for CRPO training."""
    args = parse_args()

    # Load the list of items (one list of samples per prompt).
    with open(args.json_path, 'r') as f:
        data = json.load(f)

    # One worker process and one data chunk per visible GPU.
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
    chunks = split_into_chunks(data, num_gpus)

    os.makedirs(args.output_dir, exist_ok=True)

    # Manager dict through which workers hand their annotated chunks back.
    manager = multiprocessing.Manager()
    return_dict = manager.dict()

    workers = []
    for gpu_id, chunk in enumerate(chunks):
        proc = multiprocessing.Process(
            target=process_chunk,
            args=(args, chunk, gpu_id, return_dict, gpu_id),
        )
        workers.append(proc)
        proc.start()
        print(f"Started process {gpu_id} on GPU {gpu_id}")

    for proc in workers:
        proc.join()
        print(f"Process {proc.pid} has finished.")

    # Reassemble the per-GPU results in GPU order into one flat list.
    combined_data = []
    for gpu_id in range(num_gpus):
        combined_data.extend(return_dict[gpu_id])

    output_file = f"{args.output_dir}/clap_scores.json"
    with open(output_file, 'w') as f:
        json.dump(combined_data, f)
    print(f"All CLAP scores have been computed and saved to {output_file}")

    # Best-scoring sample becomes "chosen", worst becomes "reject".
    max_item = [max(x, key=lambda item: item['clap_score']) for x in combined_data]
    min_item = [min(x, key=lambda item: item['clap_score']) for x in combined_data]
    crpo_dataset = [
        {
            "captions": chosen['captions'],
            "duration": chosen['duration'],
            "chosen": chosen['path'],
            "reject": reject['path'],
        }
        for chosen, reject in zip(max_item, min_item)
    ]
    with open(f"{args.output_dir}/train.json", 'w') as f:
        json.dump(crpo_dataset, f)
if __name__ == '__main__':
    # CUDA cannot be re-initialized in forked children; 'spawn' is required
    # for the per-GPU worker processes.
    multiprocessing.set_start_method('spawn')
    main()