Spaces:
Runtime error
Runtime error
| from src.music.utilities.representation_learning_utilities.constants import * | |
| from src.music.config import REP_MODEL_NAME | |
| from src.music.utils import get_out_path | |
| import pickle | |
| import numpy as np | |
| # from transformers import AutoModel, AutoTokenizer | |
| from torch import nn | |
| from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer | |
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class Argument(object):
    """Expose the entries of a dictionary as instance attributes."""

    def __init__(self, adict):
        # Copy every entry of `adict` straight into the instance namespace.
        vars(self).update(adict)
class RepModel(nn.Module):
    """Wrap a model and reduce its last hidden state to one averaged vector.

    When the model name contains 't5', only the encoder returned by
    ``model.get_encoder()`` is kept. The wrapped model is switched to eval
    mode on construction.
    """

    def __init__(self, model, model_name):
        super().__init__()
        self.model = model.get_encoder() if 't5' in model_name else model
        self.model.eval()

    def forward(self, inputs):
        """Return the mean over axis 0 of the first entry of the last hidden state.

        Gradients are disabled for the whole computation.
        """
        with torch.no_grad():
            outputs = self.model(inputs, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
            # Only index 0 along the leading axis is used -- presumably the
            # single sequence of a batch of size one; TODO confirm with callers.
            first_entry = last_hidden[0]
            return first_entry.mean(dim=0)
| # def get_trained_music_LM(model_name): | |
| # tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True) | |
| # model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name) | |
| # | |
| # return model, tokenizer | |
def get_trained_sentence_embedder(model_name):
    """Build a ``SentenceTransformer`` for ``model_name`` on the module-level ``device``."""
    return SentenceTransformer(model_name, device=device)
# Built once at import time so every encoded2rep() call shares the same embedder.
MODEL = get_trained_sentence_embedder(model_name=REP_MODEL_NAME)
def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0):
    """Embed an encoded performance with the module-level ``MODEL`` and save it as text.

    The pickle at ``encoded_path`` is loaded, the tokens stored under its
    ``'main'`` key are filtered (entries equal to 1 are dropped), joined into
    one space-separated string and passed to ``MODEL.encode``. The result is
    written to ``rep_path`` with ``np.savetxt``.

    Args:
        encoded_path: Path of the pickled encoding. The unpickled object must
            support ``data['main']`` and yield an iterable of tokens.
        rep_path: Output text file. When falsy it is derived from
            ``encoded_path`` via ``get_out_path`` ('encoded' -> 'represented',
            extension '.txt').
        return_rep: When true, the representation is returned as well.
        verbose: Print progress messages.
        level: Number of spaces prepended to the progress messages.

    Returns:
        ``(rep_path, '')``, or ``(rep_path, representation, '')`` when
        ``return_rep`` is true. The trailing string is the error-message slot;
        it is always empty because failures are raised, not returned.

    Raises:
        AssertionError: If the filtered token count is not a multiple of 5 or
            if no token is left after filtering.
        ValueError: If a token is not an integer literal.
        Exception: Errors from opening/unpickling the input, from the model or
            from writing the output propagate unchanged.
    """
    if not rep_path:
        rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded', out_word='represented', out_extension='.txt')
    if verbose: print(' ' * level + 'Mapping to final music representations')

    # SECURITY NOTE: pickle.load can execute arbitrary code; only pass files
    # produced by this pipeline, never untrusted input.
    with open(encoded_path, 'rb') as f:
        data = pickle.load(f)

    # Entries equal to 1 are skipped -- presumably a padding/special token;
    # TODO confirm against the encoder's vocabulary.
    performance = [str(w) for w in data['main'] if w != 1]

    # Raised explicitly (instead of `assert`) so the checks survive `python -O`
    # while callers still see the same exception type as before.
    if len(performance) % 5 != 0:
        raise AssertionError(
            'Error in music transformer mapping. '
            f'Encoded performance has {len(performance)} tokens, which is not a multiple of 5.'
        )
    if not performance:
        raise AssertionError(
            'Error in music transformer mapping. Error: No midi messages in primer file'
        )

    perf = ' '.join(performance)
    # Every token must be an integer literal; int() raises ValueError otherwise.
    for token in perf.split(' '):
        int(token)

    representation = MODEL.encode(perf)
    np.savetxt(rep_path, representation)

    if verbose: print(' ' * (level + 2) + 'Success.')
    if return_rep:
        return rep_path, representation, ''
    return rep_path, ''
if __name__ == "__main__":
    # Manual smoke run on a developer-local file. Without `return_rep` the call
    # returns (output path, error message), not the representation itself, so
    # the result is unpacked under names that say so.
    rep_path, error_msg = encoded2rep("/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle")
    stop = 1  # no-op statement; presumably a debugger breakpoint anchor