fix pipeline
Browse files- test_ax_model.py +10 -2
test_ax_model.py
CHANGED
|
@@ -228,8 +228,8 @@ class FireRedASROnnxModel:
|
|
| 228 |
# self.decoder_main.get_inputs()[1].name: n_layer_self_k_cache,
|
| 229 |
"n_layer_cross_k": n_layer_cross_k_cache,
|
| 230 |
"n_layer_cross_v": n_layer_cross_v_cache,
|
| 231 |
-
"pe": pe,
|
| 232 |
-
"self_attn_mask": self_attn_mask,
|
| 233 |
"cross_attn_mask": cross_attn_mask,
|
| 234 |
# self.decoder_main.get_inputs()[7].name: cross_attn_mask,
|
| 235 |
}
|
|
@@ -356,6 +356,9 @@ class FireRedASROnnxModel:
|
|
| 356 |
n_layer_cross_v = to_numpy(n_layer_cross_v)
|
| 357 |
cross_attn_mask = to_numpy(cross_attn_mask)
|
| 358 |
|
|
|
|
|
|
|
|
|
|
| 359 |
# for name, npy in zip(
|
| 360 |
# ["tokens", "n_layer_self_k_cache", "n_layer_self_v_cache", "n_layer_cross_k", "n_layer_cross_v", "pe", "self_attn_mask", "cross_attn_mask"],
|
| 361 |
# [tokens, n_layer_self_k_cache, n_layer_self_v_cache, n_layer_cross_k, n_layer_cross_v, self.pe[offset], self_attn_mask, cross_attn_mask]
|
|
@@ -365,6 +368,7 @@ class FireRedASROnnxModel:
|
|
| 365 |
# np.save(os.path.join(file_path, f"{i}.npy"), npy)
|
| 366 |
|
| 367 |
if i == 0:
|
|
|
|
| 368 |
logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_main_one_token(
|
| 369 |
to_numpy(tokens),
|
| 370 |
to_numpy(n_layer_self_k_cache),
|
|
@@ -375,7 +379,9 @@ class FireRedASROnnxModel:
|
|
| 375 |
self_attn_mask,
|
| 376 |
to_numpy(cross_attn_mask)
|
| 377 |
)
|
|
|
|
| 378 |
else:
|
|
|
|
| 379 |
logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
|
| 380 |
to_numpy(tokens),
|
| 381 |
to_numpy(n_layer_self_k_cache),
|
|
@@ -386,6 +392,7 @@ class FireRedASROnnxModel:
|
|
| 386 |
self_attn_mask,
|
| 387 |
to_numpy(cross_attn_mask)
|
| 388 |
)
|
|
|
|
| 389 |
|
| 390 |
offset += 1
|
| 391 |
logits = torch.from_numpy(logits)
|
|
@@ -513,6 +520,7 @@ class FireRedASROnnxModel:
|
|
| 513 |
to_numpy(feats),
|
| 514 |
to_numpy(lengths)
|
| 515 |
)
|
|
|
|
| 516 |
nbest_hyps = self.run_decoder(n_layer_cross_k,
|
| 517 |
n_layer_cross_v,
|
| 518 |
cross_attn_mask,
|
|
|
|
| 228 |
# self.decoder_main.get_inputs()[1].name: n_layer_self_k_cache,
|
| 229 |
"n_layer_cross_k": n_layer_cross_k_cache,
|
| 230 |
"n_layer_cross_v": n_layer_cross_v_cache,
|
| 231 |
+
# "pe": pe,
|
| 232 |
+
# "self_attn_mask": self_attn_mask,
|
| 233 |
"cross_attn_mask": cross_attn_mask,
|
| 234 |
# self.decoder_main.get_inputs()[7].name: cross_attn_mask,
|
| 235 |
}
|
|
|
|
| 356 |
n_layer_cross_v = to_numpy(n_layer_cross_v)
|
| 357 |
cross_attn_mask = to_numpy(cross_attn_mask)
|
| 358 |
|
| 359 |
+
self_attn_mask = np.zeros((batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32)
|
| 360 |
+
self_attn_mask[:, :, :self.decode_max_len - offset[0] - 1] = -np.inf
|
| 361 |
+
|
| 362 |
# for name, npy in zip(
|
| 363 |
# ["tokens", "n_layer_self_k_cache", "n_layer_self_v_cache", "n_layer_cross_k", "n_layer_cross_v", "pe", "self_attn_mask", "cross_attn_mask"],
|
| 364 |
# [tokens, n_layer_self_k_cache, n_layer_self_v_cache, n_layer_cross_k, n_layer_cross_v, self.pe[offset], self_attn_mask, cross_attn_mask]
|
|
|
|
| 368 |
# np.save(os.path.join(file_path, f"{i}.npy"), npy)
|
| 369 |
|
| 370 |
if i == 0:
|
| 371 |
+
start_time = time.time()
|
| 372 |
logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_main_one_token(
|
| 373 |
to_numpy(tokens),
|
| 374 |
to_numpy(n_layer_self_k_cache),
|
|
|
|
| 379 |
self_attn_mask,
|
| 380 |
to_numpy(cross_attn_mask)
|
| 381 |
)
|
| 382 |
+
print(f"run decoder_main take {(time.time() - start_time) * 1000}ms")
|
| 383 |
else:
|
| 384 |
+
start_time = time.time()
|
| 385 |
logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
|
| 386 |
to_numpy(tokens),
|
| 387 |
to_numpy(n_layer_self_k_cache),
|
|
|
|
| 392 |
self_attn_mask,
|
| 393 |
to_numpy(cross_attn_mask)
|
| 394 |
)
|
| 395 |
+
print(f"run decoder_loop take {(time.time() - start_time) * 1000}ms")
|
| 396 |
|
| 397 |
offset += 1
|
| 398 |
logits = torch.from_numpy(logits)
|
|
|
|
| 520 |
to_numpy(feats),
|
| 521 |
to_numpy(lengths)
|
| 522 |
)
|
| 523 |
+
print(f"run encoder take {(time.time() - start_time) * 1000}ms")
|
| 524 |
nbest_hyps = self.run_decoder(n_layer_cross_k,
|
| 525 |
n_layer_cross_v,
|
| 526 |
cross_attn_mask,
|