Spaces:
Runtime error
Runtime error
# Standard library
import argparse
import os
import random
import time

# Third-party
import torch
from samplings import top_p_sampling, temperature_sampling
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from unidecode import unidecode
def generate_abc(args):
    """Generate ABC-notation tunes from the text in ``input_text.txt``.

    Loads the ``sander-wood/text-to-music`` seq2seq model, samples
    ``args.num_tunes`` tunes token by token with nucleus (top-p) +
    temperature sampling, prints them as they are produced, and writes
    the concatenated result to ``output_tunes/<timestamp>.abc``.

    Args:
        args: argparse.Namespace with fields ``num_tunes``, ``max_length``,
            ``top_p``, ``temperature`` and ``seed`` (``None`` for
            non-deterministic sampling).

    Side effects: reads ``input_text.txt``, downloads/caches the model,
    creates ``output_tunes/`` and writes one ``.abc`` file into it.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    # Prefer the first CUDA device when available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0), '\n')
    else:
        print('No GPU available, using the CPU instead.\n')
        device = torch.device("cpu")

    num_tunes = args.num_tunes
    max_length = args.max_length
    top_p = args.top_p
    temperature = args.temperature
    seed = args.seed

    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    args = vars(args)
    for key in args.keys():
        print(key+': '+str(args[key]))

    # unidecode folds the prompt to plain ASCII, matching the model's training text.
    with open('input_text.txt') as f:
        text = unidecode(f.read())
    print("\n"+" INPUT TEXT ".center(60, "#"))
    print('\n'+text+'\n')

    tokenizer = AutoTokenizer.from_pretrained('sander-wood/text-to-music')
    model = AutoModelForSeq2SeqLM.from_pretrained('sander-wood/text-to-music')
    model = model.to(device)
    model.eval()  # inference only: disable dropout etc.

    input_ids = tokenizer(text,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length)['input_ids'].to(device)

    decoder_start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id

    random.seed(seed)
    tunes = ""
    print(" OUTPUT TUNES ".center(60, "#"))

    for n_idx in range(num_tunes):
        print("\nX:"+str(n_idx+1)+"\n", end="")
        tunes += "X:"+str(n_idx+1)+"\n"
        decoder_input_ids = torch.tensor([[decoder_start_token_id]])
        for t_idx in range(max_length):
            # Derive a fresh per-token seed so sampling is reproducible
            # from the single user-supplied seed.
            if seed is not None:
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)
            else:
                n_seed = None
            # BUG FIX: run the forward pass under no_grad(); without it an
            # autograd graph accumulates over every generated token and
            # memory grows until the run crashes or slows to a crawl.
            with torch.no_grad():
                outputs = model(input_ids=input_ids,
                                decoder_input_ids=decoder_input_ids.to(device))
            # Distribution over the vocabulary for the next token.
            probs = torch.softmax(outputs.logits[0][-1], dim=-1).cpu().numpy()
            sampled_id = temperature_sampling(probs=top_p_sampling(probs,
                                                                   top_p=top_p,
                                                                   seed=n_seed,
                                                                   return_probs=True),
                                              seed=n_seed,
                                              temperature=temperature)
            decoder_input_ids = torch.cat((decoder_input_ids,
                                           torch.tensor([[sampled_id]])), 1)
            if sampled_id != eos_token_id:
                sampled_token = tokenizer.decode([sampled_id])
                print(sampled_token, end="")
                tunes += sampled_token
            else:
                tunes += '\n'
                break

    # BUG FIX: create the output directory; open() does not create parent
    # directories and raised FileNotFoundError on a fresh checkout.
    os.makedirs('output_tunes', exist_ok=True)
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    with open('output_tunes/'+timestamp+'.abc', 'w') as f:
        f.write(unidecode(tunes))
def get_args(parser):
    """Attach the generation CLI options to *parser* and parse sys.argv.

    Returns the populated argparse.Namespace.
    """
    options = (
        ('-num_tunes', int, 3,
         'the number of independently computed returned tunes'),
        ('-max_length', int, 1024,
         'integer to define the maximum length in tokens of each tune'),
        ('-top_p', float, 0.9,
         'float to define the tokens that are within the sample operation of text generation'),
        ('-temperature', float, 1.,
         'the temperature of the sampling operation'),
        ('-seed', int, None,
         'seed for randomstate'),
    )
    for flag, flag_type, default, help_text in options:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: build the CLI, parse flags, run generation.
    generate_abc(get_args(argparse.ArgumentParser()))