File size: 8,424 Bytes
12b2634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Root folder that is walked recursively for input files (.txt, .abc, .mtf);
# the original note calls it the "interleaved abc folder". Empty by default:
# fill it in before running the script.
input_dir = ''
# Root folder that receives the extracted features as .npy files, mirroring
# the sub-directory layout found under input_dir. Empty by default.
output_dir = ''

import os
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer
import argparse


# Default for extract_feature(get_normalized=...): True yields one pooled
# feature per file, False yields the concatenated per-position hidden states.
normalize = True

# Create the accelerator before touching the filesystem so that the one-off
# preparation below can run on a single process while the others wait.
accelerator = Accelerator()
device = accelerator.device

os.makedirs("logs", exist_ok=True)

# Non-main processes keep an empty list here; the shuffled list is reloaded
# from logs/files_shuffle_extract_clamp2.json before processing starts.
files = []
if accelerator.is_main_process:
    # Reset the artefacts of a previous run. "logs/log_extract_clamp.txt" is
    # the log this script actually appends to; the "clamp2"-suffixed log name
    # is kept in the list so stale files with that name are still cleaned up.
    for file in ["logs/files_extract_clamp2.json",
                 "logs/files_shuffle_extract_clamp2.json",
                 "logs/log_extract_clamp.txt",
                 "logs/log_extract_clamp2.txt",
                 "logs/pass_extract_clamp2.txt",
                 "logs/skip_extract_clamp2.txt"]:
        try:
            os.remove(file)
        except FileNotFoundError:
            # Nothing to clean up for this name.
            pass

    # Collect every supported input file below input_dir.
    for root, dirs, fs in os.walk(input_dir):
        for f in fs:
            if f.endswith((".txt", ".abc", ".mtf")):
                files.append(os.path.join(root, f))
    print(f"Found {len(files)} files in total")
    with open("logs/files_extract_clamp2.json", "w", encoding="utf-8") as f:
        json.dump(files, f)

    # Shuffle once, on one process only, so that every process later slices
    # the same ordering; independent per-process shuffles would make the
    # per-GPU partitions overlap or miss files.
    random.shuffle(files)
    with open("logs/files_shuffle_extract_clamp2.json", "w", encoding="utf-8") as f:
        json.dump(files, f)

# Barrier: nobody proceeds until the main process has written the file lists.
# NOTE(review): this assumes all processes share the "logs" directory
# (single node or shared filesystem) - confirm for multi-node runs.
accelerator.wait_for_everyone()

print("Using device:", device)
with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as f:
    f.write("Using device: " + str(device) + "\n")

# Configuration of the M3 (music) encoder; all sizes come from config.py.
m3_config = BertConfig(vocab_size=1,
                        hidden_size=M3_HIDDEN_SIZE,
                        num_hidden_layers=PATCH_NUM_LAYERS,
                        num_attention_heads=M3_HIDDEN_SIZE//64,
                        intermediate_size=M3_HIDDEN_SIZE*4,
                        max_position_embeddings=PATCH_LENGTH)
model = CLaMP2Model(m3_config,
                    text_model_name=TEXT_MODEL_NAME,
                    hidden_size=CLAMP2_HIDDEN_SIZE,
                    load_m3=CLAMP2_LOAD_M3)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# print parameter number (trainable parameters only)
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

model.eval()
# The checkpoint is deserialized on the CPU with weights_only=True and then
# loaded into the model that already lives on `device`.
checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True)
print(f"Successfully Loaded CLaMP 2 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])

def extract_feature(filename, get_normalized=normalize):
    """Extract CLaMP 2 features for a single text or music file.

    Files ending in ".txt" are tokenized with the text tokenizer; every other
    file is encoded with the M3 patchilizer. The encoded sequence is cut into
    windows of at most the model input length, each window is padded, run
    through the model, and the per-window outputs are merged.

    Args:
        filename: Path of the UTF-8 encoded file to read.
        get_normalized: Forwarded to the model. When true, the per-window
            outputs are averaged, each window weighted by the number of
            sequence positions it newly contributes. When false, the
            per-position hidden states of all windows are concatenated in
            order, without padding and without the overlap of the last window.

    Returns:
        A tensor on the model device: the weighted average over windows when
        ``get_normalized`` is true, otherwise the concatenated hidden states
        along the first axis.

    Raises:
        ValueError: If the file encodes to an empty sequence.
    """
    is_text = filename.endswith(".txt")

    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Drop lines that start with a single '%'; lines starting with '%%' are
    # kept.
    item = ''.join(
        line for line in lines
        if not line.startswith('%') or line.startswith('%%')
    )

    if is_text:
        # De-duplicate the lines while keeping their first-seen order.
        # A set() would also de-duplicate, but its iteration order depends on
        # string hash randomization, which made the extracted text features
        # differ from one run to the next.
        unique_lines = [c for c in dict.fromkeys(item.split("\n")) if len(c) > 0]
        item = tokenizer.sep_token.join(unique_lines)
        input_data = tokenizer(item, return_tensors="pt")
        input_data = input_data['input_ids'].squeeze(0)
        max_input_length = MAX_TEXT_LENGTH
    else:
        input_data = patchilizer.encode(item, add_special_patches=True)
        input_data = torch.tensor(input_data)
        max_input_length = PATCH_LENGTH

    if len(input_data) == 0:
        # Without this guard the windowing below fails with an opaque
        # IndexError on an empty segment list.
        raise ValueError(f"{filename} produced an empty input sequence")

    # Split into consecutive windows; the last window is re-anchored to the
    # end of the sequence so it is as full as possible (it may overlap the
    # previous window).
    segment_list = [
        input_data[i:i+max_input_length]
        for i in range(0, len(input_data), max_input_length)
    ]
    segment_list[-1] = input_data[-max_input_length:]

    last_hidden_states_list = []

    for input_segment in segment_list:
        pad_len = max_input_length - input_segment.size(0)
        input_masks = torch.tensor([1]*input_segment.size(0))
        if is_text:
            pad_indices = torch.ones(pad_len).long() * tokenizer.pad_token_id
        else:
            pad_indices = torch.ones((pad_len, PATCH_SIZE)).long() * patchilizer.pad_token_id
        # Mask is 1 for real positions and 0 for padding.
        input_masks = torch.cat((input_masks, torch.zeros(pad_len)), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)

        if is_text:
            last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
                                                         text_masks=input_masks.unsqueeze(0).to(device),
                                                         get_normalized=get_normalized)
        else:
            last_hidden_states = model.get_music_features(music_inputs=input_segment.unsqueeze(0).to(device),
                                                          music_masks=input_masks.unsqueeze(0).to(device),
                                                          get_normalized=get_normalized)
        if not get_normalized:
            # Strip the states that belong to padding positions.
            last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
        last_hidden_states_list.append(last_hidden_states)

    remain_chunk_len = len(input_data) % max_input_length

    if not get_normalized:
        last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
        # Keep only the positions of the last window that earlier windows did
        # not cover. With a remainder of 0 the slice [-0:] keeps the whole
        # window, which is correct because the last window is then a full,
        # non-overlapping one.
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-remain_chunk_len:]
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
    else:
        # Weight every window by the number of positions it contributes: full
        # windows count max_input_length, a trailing partial window counts
        # only its remainder.
        full_chunk_cnt = len(input_data) // max_input_length
        if remain_chunk_len == 0:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
        else:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)

        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
        last_hidden_states_list = last_hidden_states_list * feature_weights
        last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()

    return last_hidden_states_list

def process_directory(input_dir, output_dir, files):
    """Extract features for this process's share of ``files``.

    The list is split into equally sized contiguous slices, one per
    accelerate process; the last process additionally takes the remainder.
    For every file a ".npy" feature file is written below ``output_dir``,
    mirroring the file's sub-directory relative to ``input_dir``. Files whose
    output already exists are skipped. Failures are logged and do not stop
    the run.

    Args:
        input_dir: Root directory the entries of ``files`` were found under.
        output_dir: Root directory for the generated ".npy" files.
        files: Paths of all files to process (across all processes).
    """
    print(f"Found {len(files)} files in total")
    with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as f:
        f.write("Found " + str(len(files)) + " files in total\n")

    # calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # calculate the start and end index for the current GPU
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    if accelerator.process_index == accelerator.num_processes - 1:
        # the last process also handles the leftover files
        end_idx = len(files)

    files_to_process = files[start_idx:end_idx]

    # process the files
    for file in tqdm(files_to_process):
        # Mirror the file's directory below output_dir. Building the path
        # from a relative path instead of slicing/concatenating strings keeps
        # it correct when input_dir or output_dir carry (or lack) a trailing
        # separator.
        rel_dir = os.path.relpath(os.path.dirname(file) or os.curdir,
                                  input_dir or os.curdir)
        output_subdir = os.path.normpath(os.path.join(output_dir, rel_dir))
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(output_subdir + " can not be created\n" + str(e))
            with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as f:
                f.write(output_subdir + " can not be created\n" + str(e) + "\n")
            # Without the directory the result cannot be saved, so do not
            # spend time on extracting features for this file.
            continue

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            with open("logs/skip_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
            continue

        # Per-file boundary: any failure is logged and the run continues with
        # the next file.
        try:
            with torch.no_grad():
                features = extract_feature(file).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
            with open("logs/pass_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
        except Exception as e:
            print(f"Failed to process {file}: {e}")
            with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as f:
                f.write("Failed to process " + file + ": " + str(e) + "\n")

# Reload the shuffled file list from disk, so the list that gets partitioned
# is the one recorded in the log folder.
with open("logs/files_shuffle_extract_clamp2.json", "r", encoding="utf-8") as shuffled_list_file:
    files = json.loads(shuffled_list_file.read())

# Run the extraction for this process's share of the files.
process_directory(input_dir=input_dir, output_dir=output_dir, files=files)

# Record which device this process ran on once it has finished.
device_line = "GPU ID: " + str(device) + "\n"
with open("logs/log_extract_clamp.txt", "a", encoding="utf-8") as log_file:
    log_file.write(device_line)