|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
def remove_1(points):
    """Return only the points whose third component is not equal to 1."""
    kept = []
    for point in points:
        if point[2] != 1:
            kept.append(point)
    return kept
|
|
|
|
|
|
|
|
class CompareHelper:
    """Wrap an indexable value so instances order by their first element."""

    def __init__(self, data):
        # data: any indexable whose first element supports '<'.
        self.data = data

    def __lt__(self, other):
        # Ordering is decided solely by the leading element of each side.
        lhs = self.data[0]
        rhs = other.data[0]
        return lhs < rhs
|
|
|
|
|
|
|
|
def get_duration_in_interval(chord, start_interval, end_interval):
    """Return how long the chord sounds inside [start_interval, end_interval].

    A non-overlapping chord yields a negative value, exactly as the
    original min/max arithmetic does — callers are expected to handle it.
    """
    overlap_start = max(chord['start'], start_interval)
    overlap_end = min(chord['end'], end_interval)
    return overlap_end - overlap_start
|
|
|
|
|
|
|
|
def shift_image_optimized(image, x_shift, y_shift):
    """Shift a batch of images by (x_shift, y_shift) pixels, zero-filling
    the region exposed by the shift instead of letting torch.roll wrap.

    :param image: torch.Tensor of shape (batch, channels, height, width)
    :param x_shift: horizontal shift in pixels (positive -> toward higher x)
    :param y_shift: vertical shift in pixels (positive -> toward higher y)
    :return: shifted tensor with the same shape as ``image``
    """
    # Unpacking validates that the input is 4-D; height/width are otherwise unused.
    _, _, height, width = image.size()

    shifted_image = torch.roll(image, shifts=(x_shift, y_shift), dims=(3, 2))

    # Clear the columns that wrapped around horizontally.
    if x_shift > 0:
        shifted_image[:, :, :, :x_shift] = 0
    elif x_shift < 0:
        shifted_image[:, :, :, x_shift:] = 0

    # BUG FIX: the original only cleared the horizontal wrap; a non-zero
    # y_shift leaked rows from the opposite edge of the image. Clear the
    # vertically wrapped rows symmetrically.
    if y_shift > 0:
        shifted_image[:, :, :y_shift, :] = 0
    elif y_shift < 0:
        shifted_image[:, :, y_shift:, :] = 0

    return shifted_image
|
|
|
|
|
|
|
|
def algorithmic_collate3(batch):
    """Collate a batch of (image-list, label-list, point-list) triples by
    concatenating each field across the whole batch.

    :param batch: iterable of 3-tuples, each element itself a list
    :return: (flat_images, flat_labels, flat_points) — three flat lists
    """
    imgs, labels, points = zip(*batch)
    flat_images = [img for img_list in imgs for img in img_list]
    flat_labels = [lab for label_list in labels for lab in label_list]
    flat_points = [pt for point_list in points for pt in point_list]
    return flat_images, flat_labels, flat_points
|
|
|
|
|
def quantize_image(image):
    """
    Quantize a binary image along its last axis from 192 to 64 columns.

    Output column ``i`` is the logical OR of the source window centred on
    source column ``3*i``, i.e. columns ``[3*i - 1, 3*i + 1]`` clamped to
    the valid range (so column 0 only covers source columns 0..1, and the
    last source column, 191, is never covered — same as the original).

    :param image: torch.Tensor, shape [1, 128, 192], binary values
    :return: torch.Tensor, shape [1, 128, 64], quantized values
    """
    quantized = torch.zeros(1, 128, 64)
    for col in range(64):
        lo = max(0, 3 * col - 1)
        hi = 3 * col + 2  # exclusive upper bound of the centred window
        quantized[:, :, col] = (image[:, :, lo:hi].sum(dim=2) > 0).float()
    return quantized
|
|
|
|
|
def piano_roll_to_chroma(piano_roll):
    """
    Convert a binary piano roll tensor to a binary chroma tensor.

    Parameters:
        piano_roll (torch.Tensor): Binary piano roll of shape
            (batch_size, num_channels, num_pitches, num_frames).

    Returns:
        torch.Tensor: Binary chroma tensor of shape
            (batch_size, num_channels, 12, num_frames).
    """
    # Already chroma-sized: pass through untouched (note: not re-binarized).
    if piano_roll.shape[2] == 12:
        return piano_roll

    binary = (piano_roll > 0).float()
    batch, channels, _, frames = binary.shape
    chroma = torch.zeros((batch, channels, 12, frames), device=binary.device)

    for pitch_class in range(12):
        # OR (via max) across every octave of this pitch class.
        chroma[:, :, pitch_class, :] = binary[:, :, pitch_class::12, :].max(dim=2).values

    return chroma
|
|
|
|
|
def calculate_correlation(tensor1, tensor2, max_shift, device):
    """Maximum cosine similarity between every row pair of two 2-D tensors
    over all circular shifts of tensor2's rows in [-max_shift, max_shift].

    :param tensor1: torch.Tensor of shape (n1, d)
    :param tensor2: torch.Tensor of shape (n2, d)
    :param max_shift: maximum absolute circular shift applied along dim 1
    :param device: device the (n1, n2) result tensor is allocated on
    :return: torch.Tensor of shape (n1, n2), per-pair maximum cosine similarity
    """
    max_correlation = torch.full((tensor1.size(0), tensor2.size(0)), float('-inf')).to(device)

    # Hoisted out of the loop: tensor1 is invariant across shifts
    # (the original recomputed this normalization every iteration).
    tensor1_norm = tensor1 / tensor1.norm(dim=1, keepdim=True)

    for shift in range(-max_shift, max_shift + 1):
        shifted_tensor2 = torch.roll(tensor2, shifts=shift, dims=1)

        # torch.roll preserves per-row norms, so normalizing by the shifted
        # tensor's own norm is numerically identical to the original code
        # (which divided by the unshifted tensor's norm) but self-consistent.
        tensor2_norm = shifted_tensor2 / shifted_tensor2.norm(dim=1, keepdim=True)

        cosine_similarity = torch.mm(tensor1_norm, tensor2_norm.t())
        max_correlation = torch.max(max_correlation, cosine_similarity)

    return max_correlation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infos_to_pianorolls(info, use_all):
    """Build vocal piano rolls and CONLON point lists for each song segment.

    Parameters:
        info: mapping with at least 'vocal_info' (per-bar note lists keyed by
            the stringified bar index, or None) plus the timing fields read
            by infos_to_startpoint ('downbeat_start', 'beat_times').
        use_all: forwarded to infos_to_startpoint; when truthy, every bar
            index up to the last refined breakpoint is used.

    Returns:
        (pianorolls, start_points, CONLON_points) where pianorolls['vocal']
        and CONLON_points['vocal'] map segment index -> the corresponding
        output of bar_notes_to_pianoroll for the 4 bars starting at that
        segment's bar index.
    """
    pianorolls={}
    CONLON_points={}
    vocal_pianorolls={}
    vocal_CONLON_points={}
    # Candidate 4-bar segment start indices derived from beat timing.
    start_points = infos_to_startpoint(info, use_all)
    # NOTE(review): shift_val is always 0 here, and bar_notes_to_pianoroll
    # currently ignores it — presumably a hook for future transposition.
    shift_val = 0
    for idx, i in enumerate(start_points):
        # The triple-quoted block below is disabled bass-cleansing code kept
        # by the original author; it references names (chart_info,
        # get_best_bass, info.bass_info) not visible in this file.
        """
        cleansed_bass={}
        for key, bar in info.bass_info.items():
            if len(bar)>0:
                bar=np.array(bar)
                remain_notes=[]
                to_quantize = 16 # 16๋ถ ์ํ ํ๋๋น ์ต๋ 1๊ฐ์ Note๋ฅผ ๋จ๊น๋๋ค.
                idx_quantize = 48/to_quantize
                for j in range(to_quantize):
                    bass_idx = np.where((bar[:,4]//idx_quantize == j))
                    notes = bar[bass_idx]
                    best_note = get_best_bass(chart_info, notes)
                    if best_note is not None:
                        remain_notes.append(best_note)
                cleansed_bass[key] = np.array(remain_notes)
        """
        # Collect the note lists for the 4 consecutive bars i..i+3;
        # missing bars (or a missing vocal_info entirely) fall back to [].
        vocal = [
            info['vocal_info'].get(str(i), []) if info['vocal_info'] is not None else [],
            info['vocal_info'].get(str(i+1), []) if info['vocal_info'] is not None else [],
            info['vocal_info'].get(str(i+2), []) if info['vocal_info'] is not None else [],
            info['vocal_info'].get(str(i+3), []) if info['vocal_info'] is not None else []
        ]
        # Render the 4-bar window into one piano roll + CONLON point list.
        vocal_pianoroll,vocal_CONLON_point = bar_notes_to_pianoroll(vocal, shift_val)
        vocal_pianorolls[idx] = vocal_pianoroll
        vocal_CONLON_points[idx] = vocal_CONLON_point
    pianorolls['vocal'] = vocal_pianorolls
    CONLON_points['vocal'] = vocal_CONLON_points
    return pianorolls, start_points, CONLON_points
|
|
|
|
|
|
|
|
|
|
|
def bar_notes_to_pianoroll(bars, shift_val):
    """Render bars of note events into one (192, 128) binary piano roll.

    Each bar occupies 48 time steps. Note field layout (by index): note[2]
    is the pitch, note[4] the start step within the bar, and note[5] the
    end step (inclusive).

    NOTE(review): shift_val is currently unused — kept for interface
    compatibility with callers.

    Returns:
        (pianoroll, conlon_points): the (time, pitch) binary roll and a list
        of [absolute_start, pitch, duration] triples, one per note.
    """
    pianoroll = np.zeros((192, 128))
    conlon_points = []
    for bar_idx, bar in enumerate(bars):
        bar_offset = bar_idx * 48
        for note in bar:
            pitch = int(note[2])
            onset = int(note[4])
            end_step = int(note[5])
            duration = end_step - onset + 1
            abs_start = onset + bar_offset
            conlon_points.append([abs_start, pitch, duration])
            # end_step is inclusive, hence the +1 on the slice bound.
            pianoroll[abs_start:end_step + bar_offset + 1, pitch] = 1
    return pianoroll, conlon_points
|
|
|
|
|
def infos_to_startpoint(info, use_all):
    """Derive candidate segment start indices from a song's beat timing.

    The bar count is estimated from the span between the first downbeat and
    the last beat, assuming 4 beats per bar and a constant beat period taken
    from the first beat gap; the raw indices are then thinned out by
    refine_breakpoints_custom.

    :param info: mapping with 'downbeat_start' (float) and 'beat_times'
        (sequence of beat timestamps, at least two entries)
    :param use_all: when truthy, return every index up to the last refined
        breakpoint instead of the refined list itself
    :return: list of candidate start indices
    """
    downbeat_start = info['downbeat_start']
    beat_times = info['beat_times']
    beat_period = beat_times[1] - beat_times[0]
    # Number of whole 4-beat bars that fit between downbeat_start and the
    # final beat, minus one to get the last valid start index.
    boundary = round((beat_times[-1] - downbeat_start) / (4 * beat_period)) - 1
    candidates = list(range(boundary + 1))
    candidates = refine_breakpoints_custom(candidates)
    if use_all:
        candidates = list(range(candidates[-1]))
    return candidates
|
|
|
|
|
def refine_breakpoints_custom(breakpoints, interval=4):
    """Thin a breakpoint list onto an interval-spaced grid.

    Duplicates and non-positive entries are dropped (order preserved);
    a lead-in run of grid points on the same phase precedes the first
    breakpoint, and interval-spaced fillers are inserted between
    consecutive breakpoints, stopping one interval short of the next one.
    An empty input collapses to [0].
    """
    # Deduplicate while keeping first-seen order; only positive values kept.
    unique = []
    for bp in breakpoints:
        if bp > 0 and bp not in unique:
            unique.append(bp)
    if not unique:
        unique = [0]

    refined = []
    first = unique[0]
    phase = first % interval
    # Lead-in: grid points before the first breakpoint, on its phase.
    if phase != first:
        refined.extend(p for p in range(phase, first, interval) if p > -1)

    for pos, current in enumerate(unique):
        refined.append(current)
        if pos + 1 < len(unique):
            step = current
            # Fill between current and the next breakpoint, leaving at
            # least one full interval of gap before the next breakpoint.
            while step + 2 * interval <= unique[pos + 1]:
                step += interval
                refined.append(step)

    return refined if refined else [0]
|
|
|