Spaces:
Runtime error
Runtime error
Upload app_onnx.py
Browse files- app_onnx.py +16 -26
app_onnx.py
CHANGED
|
@@ -29,6 +29,7 @@ def softmax(x, axis):
|
|
| 29 |
exp_x_shifted = np.exp(x - x_max)
|
| 30 |
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
| 31 |
|
|
|
|
| 32 |
def sample_top_p_k(probs, p, k, generator=None):
|
| 33 |
if generator is None:
|
| 34 |
generator = np.random
|
|
@@ -48,9 +49,10 @@ def sample_top_p_k(probs, p, k, generator=None):
|
|
| 48 |
next_token = next_token.reshape(*shape[:-1])
|
| 49 |
return next_token
|
| 50 |
|
|
|
|
| 51 |
def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
|
| 52 |
io_binding = model.io_binding()
|
| 53 |
-
for input_ in
|
| 54 |
name = input_.name
|
| 55 |
if name.startswith("past_key_values"):
|
| 56 |
present_name = name.replace("past_key_values", "present")
|
|
@@ -80,8 +82,7 @@ def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, pa
|
|
| 80 |
return io_binding
|
| 81 |
|
| 82 |
def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
| 83 |
-
disable_patch_change=False, disable_control_change=False, disable_channels=None,
|
| 84 |
-
repetition_penalty=1.0, generator=None):
|
| 85 |
tokenizer = model[2]
|
| 86 |
if disable_channels is not None:
|
| 87 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
|
@@ -106,7 +107,7 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
|
|
| 106 |
prompt = prompt[..., :max_token_seq]
|
| 107 |
if prompt.shape[-1] < max_token_seq:
|
| 108 |
prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
|
| 109 |
-
|
| 110 |
input_tensor = prompt
|
| 111 |
cur_len = input_tensor.shape[1]
|
| 112 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
|
|
@@ -161,6 +162,7 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
|
|
| 161 |
mask = mask[:, None, :]
|
| 162 |
x = next_token_seq
|
| 163 |
if i != 0:
|
|
|
|
| 164 |
if i == 1:
|
| 165 |
hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
|
| 166 |
model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
|
|
@@ -176,16 +178,6 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
|
|
| 176 |
model[1].run_with_iobinding(io_binding)
|
| 177 |
io_binding.synchronize_outputs()
|
| 178 |
logits = model1_outputs["y"].numpy()
|
| 179 |
-
|
| 180 |
-
# Apply repetition penalty
|
| 181 |
-
if repetition_penalty != 1.0:
|
| 182 |
-
for b in range(batch_size):
|
| 183 |
-
if not end[b]:
|
| 184 |
-
prev_tokens = input_tensor[b, :cur_len].tolist()
|
| 185 |
-
used_tokens = set(prev_tokens)
|
| 186 |
-
for token in used_tokens:
|
| 187 |
-
logits[b, :, token] /= repetition_penalty
|
| 188 |
-
|
| 189 |
scores = softmax(logits / temp, -1) * mask
|
| 190 |
samples = sample_top_p_k(scores, top_p, top_k, generator)
|
| 191 |
if i == 0:
|
|
@@ -204,8 +196,8 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
|
|
| 204 |
break
|
| 205 |
if next_token_seq.shape[1] < max_token_seq:
|
| 206 |
next_token_seq = np.pad(next_token_seq,
|
| 207 |
-
|
| 208 |
-
|
| 209 |
next_token_seq = next_token_seq[:, None, :]
|
| 210 |
input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
|
| 211 |
past_len = cur_len
|
|
@@ -594,12 +586,10 @@ if __name__ == "__main__":
|
|
| 594 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
| 595 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
|
| 596 |
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
|
| 597 |
-
input_rep_penalty = gr.Slider(label="repetition penalty", minimum=1.0, maximum=2.0,
|
| 598 |
-
step=0.05, value=1.0)
|
| 599 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
| 600 |
input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
|
| 601 |
example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
|
| 602 |
-
|
| 603 |
run_btn = gr.Button("generate", variant="primary")
|
| 604 |
# stop_btn = gr.Button("stop and output")
|
| 605 |
output_midi_seq = gr.State()
|
|
@@ -615,13 +605,13 @@ if __name__ == "__main__":
|
|
| 615 |
midi_outputs.append(output_midi)
|
| 616 |
audio_outputs.append(output_audio)
|
| 617 |
run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
finish_run_event = run_event.then(fn=finish_run,
|
| 626 |
inputs=[input_model, output_midi_seq],
|
| 627 |
outputs=midi_outputs + [js_msg],
|
|
|
|
| 29 |
exp_x_shifted = np.exp(x - x_max)
|
| 30 |
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
| 31 |
|
| 32 |
+
|
| 33 |
def sample_top_p_k(probs, p, k, generator=None):
|
| 34 |
if generator is None:
|
| 35 |
generator = np.random
|
|
|
|
| 49 |
next_token = next_token.reshape(*shape[:-1])
|
| 50 |
return next_token
|
| 51 |
|
| 52 |
+
|
| 53 |
def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
|
| 54 |
io_binding = model.io_binding()
|
| 55 |
+
for input_ in model.get_inputs():
|
| 56 |
name = input_.name
|
| 57 |
if name.startswith("past_key_values"):
|
| 58 |
present_name = name.replace("past_key_values", "present")
|
|
|
|
| 82 |
return io_binding
|
| 83 |
|
| 84 |
def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
| 85 |
+
disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
|
|
|
|
| 86 |
tokenizer = model[2]
|
| 87 |
if disable_channels is not None:
|
| 88 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
|
|
|
| 107 |
prompt = prompt[..., :max_token_seq]
|
| 108 |
if prompt.shape[-1] < max_token_seq:
|
| 109 |
prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
|
| 110 |
+
mode="constant", constant_values=tokenizer.pad_id)
|
| 111 |
input_tensor = prompt
|
| 112 |
cur_len = input_tensor.shape[1]
|
| 113 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
|
|
|
|
| 162 |
mask = mask[:, None, :]
|
| 163 |
x = next_token_seq
|
| 164 |
if i != 0:
|
| 165 |
+
# cached
|
| 166 |
if i == 1:
|
| 167 |
hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
|
| 168 |
model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
|
|
|
|
| 178 |
model[1].run_with_iobinding(io_binding)
|
| 179 |
io_binding.synchronize_outputs()
|
| 180 |
logits = model1_outputs["y"].numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
scores = softmax(logits / temp, -1) * mask
|
| 182 |
samples = sample_top_p_k(scores, top_p, top_k, generator)
|
| 183 |
if i == 0:
|
|
|
|
| 196 |
break
|
| 197 |
if next_token_seq.shape[1] < max_token_seq:
|
| 198 |
next_token_seq = np.pad(next_token_seq,
|
| 199 |
+
((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
|
| 200 |
+
mode="constant", constant_values=tokenizer.pad_id)
|
| 201 |
next_token_seq = next_token_seq[:, None, :]
|
| 202 |
input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
|
| 203 |
past_len = cur_len
|
|
|
|
| 586 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
| 587 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
|
| 588 |
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
|
|
|
|
|
|
|
| 589 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
| 590 |
input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
|
| 591 |
example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
|
| 592 |
+
[input_temp, input_top_p, input_top_k])
|
| 593 |
run_btn = gr.Button("generate", variant="primary")
|
| 594 |
# stop_btn = gr.Button("stop and output")
|
| 595 |
output_midi_seq = gr.State()
|
|
|
|
| 605 |
midi_outputs.append(output_midi)
|
| 606 |
audio_outputs.append(output_audio)
|
| 607 |
run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
|
| 608 |
+
input_continuation_select, input_instruments, input_drum_kit, input_bpm,
|
| 609 |
+
input_time_sig, input_key_sig, input_midi, input_midi_events,
|
| 610 |
+
input_reduce_cc_st, input_remap_track_channel,
|
| 611 |
+
input_add_default_instr, input_remove_empty_channels,
|
| 612 |
+
input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
|
| 613 |
+
input_top_k, input_allow_cc],
|
| 614 |
+
[output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
|
| 615 |
finish_run_event = run_event.then(fn=finish_run,
|
| 616 |
inputs=[input_model, output_midi_seq],
|
| 617 |
outputs=midi_outputs + [js_msg],
|