Commit
·
7bfacad
1
Parent(s):
ae3131d
Prep for ALICE training
Browse files- Maestro/.DS_Store +0 -0
- train.py +1 -44
Maestro/.DS_Store
CHANGED
|
Binary files a/Maestro/.DS_Store and b/Maestro/.DS_Store differ
|
|
|
train.py
CHANGED
|
@@ -182,47 +182,4 @@ trainer.save_model() # Saves the tokenizer too
|
|
| 182 |
trainer.log_metrics("train", train_result.metrics)
|
| 183 |
trainer.save_metrics("train", train_result.metrics)
|
| 184 |
trainer.save_state()
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
|
| 188 |
-
generation_config = GenerationConfig(
|
| 189 |
-
max_new_tokens=512, # extends samples by 512 tokens
|
| 190 |
-
num_beams=1, # no beam search
|
| 191 |
-
do_sample=True, # but sample instead
|
| 192 |
-
temperature=0.9,
|
| 193 |
-
top_k=15,
|
| 194 |
-
top_p=0.95,
|
| 195 |
-
epsilon_cutoff=3e-4,
|
| 196 |
-
eta_cutoff=1e-3,
|
| 197 |
-
pad_token_id=config.padding_token_id,
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
# Here the sequences are padded to the left, so that the last token along the time dimension
|
| 201 |
-
# is always the last token of each seq, allowing to efficiently generate by batch
|
| 202 |
-
collator.pad_on_left = True
|
| 203 |
-
collator.eos_token = None
|
| 204 |
-
dataloader_test = DataLoader(dataset_test, batch_size=16, collate_fn=collator)
|
| 205 |
-
model.eval()
|
| 206 |
-
count = 0
|
| 207 |
-
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'): # (N,T)
|
| 208 |
-
res = model.generate(
|
| 209 |
-
inputs=batch["input_ids"].to(model.device),
|
| 210 |
-
attention_mask=batch["attention_mask"].to(model.device),
|
| 211 |
-
generation_config=generation_config) # (N,T)
|
| 212 |
-
|
| 213 |
-
# Saves the generated music, as MIDI files and tokens (json)
|
| 214 |
-
for prompt, continuation in zip(batch["input_ids"], res):
|
| 215 |
-
generated = continuation[len(prompt):]
|
| 216 |
-
midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
|
| 217 |
-
tokens = [generated, prompt, continuation] # list compr. as seqs of dif. lengths
|
| 218 |
-
tokens = [seq.tolist() for seq in tokens]
|
| 219 |
-
for tok_seq in tokens[1:]:
|
| 220 |
-
_midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
|
| 221 |
-
midi.instruments.append(_midi.instruments[0])
|
| 222 |
-
midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
|
| 223 |
-
midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
|
| 224 |
-
midi.instruments[2].name = f'Original sample and continuation'
|
| 225 |
-
midi.dump_midi(gen_results_path / f'{count}.mid')
|
| 226 |
-
tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')
|
| 227 |
-
|
| 228 |
-
count += 1
|
|
|
|
| 182 |
trainer.log_metrics("train", train_result.metrics)
|
| 183 |
trainer.save_metrics("train", train_result.metrics)
|
| 184 |
trainer.save_state()
|
| 185 |
+
trainer.push_to_hub()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|