inoryQwQ committed on
Commit
40049a2
·
1 Parent(s): 136adf2

fix pipeline

Browse files
Files changed (1) hide show
  1. test_ax_model.py +10 -2
test_ax_model.py CHANGED
@@ -228,8 +228,8 @@ class FireRedASROnnxModel:
228
  # self.decoder_main.get_inputs()[1].name: n_layer_self_k_cache,
229
  "n_layer_cross_k": n_layer_cross_k_cache,
230
  "n_layer_cross_v": n_layer_cross_v_cache,
231
- "pe": pe,
232
- "self_attn_mask": self_attn_mask,
233
  "cross_attn_mask": cross_attn_mask,
234
  # self.decoder_main.get_inputs()[7].name: cross_attn_mask,
235
  }
@@ -356,6 +356,9 @@ class FireRedASROnnxModel:
356
  n_layer_cross_v = to_numpy(n_layer_cross_v)
357
  cross_attn_mask = to_numpy(cross_attn_mask)
358
 
 
 
 
359
  # for name, npy in zip(
360
  # ["tokens", "n_layer_self_k_cache", "n_layer_self_v_cache", "n_layer_cross_k", "n_layer_cross_v", "pe", "self_attn_mask", "cross_attn_mask"],
361
  # [tokens, n_layer_self_k_cache, n_layer_self_v_cache, n_layer_cross_k, n_layer_cross_v, self.pe[offset], self_attn_mask, cross_attn_mask]
@@ -365,6 +368,7 @@ class FireRedASROnnxModel:
365
  # np.save(os.path.join(file_path, f"{i}.npy"), npy)
366
 
367
  if i == 0:
 
368
  logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_main_one_token(
369
  to_numpy(tokens),
370
  to_numpy(n_layer_self_k_cache),
@@ -375,7 +379,9 @@ class FireRedASROnnxModel:
375
  self_attn_mask,
376
  to_numpy(cross_attn_mask)
377
  )
 
378
  else:
 
379
  logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
380
  to_numpy(tokens),
381
  to_numpy(n_layer_self_k_cache),
@@ -386,6 +392,7 @@ class FireRedASROnnxModel:
386
  self_attn_mask,
387
  to_numpy(cross_attn_mask)
388
  )
 
389
 
390
  offset += 1
391
  logits = torch.from_numpy(logits)
@@ -513,6 +520,7 @@ class FireRedASROnnxModel:
513
  to_numpy(feats),
514
  to_numpy(lengths)
515
  )
 
516
  nbest_hyps = self.run_decoder(n_layer_cross_k,
517
  n_layer_cross_v,
518
  cross_attn_mask,
 
228
  # self.decoder_main.get_inputs()[1].name: n_layer_self_k_cache,
229
  "n_layer_cross_k": n_layer_cross_k_cache,
230
  "n_layer_cross_v": n_layer_cross_v_cache,
231
+ # "pe": pe,
232
+ # "self_attn_mask": self_attn_mask,
233
  "cross_attn_mask": cross_attn_mask,
234
  # self.decoder_main.get_inputs()[7].name: cross_attn_mask,
235
  }
 
356
  n_layer_cross_v = to_numpy(n_layer_cross_v)
357
  cross_attn_mask = to_numpy(cross_attn_mask)
358
 
359
+ self_attn_mask = np.zeros((batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32)
360
+ self_attn_mask[:, :, :self.decode_max_len - offset[0] - 1] = -np.inf
361
+
362
  # for name, npy in zip(
363
  # ["tokens", "n_layer_self_k_cache", "n_layer_self_v_cache", "n_layer_cross_k", "n_layer_cross_v", "pe", "self_attn_mask", "cross_attn_mask"],
364
  # [tokens, n_layer_self_k_cache, n_layer_self_v_cache, n_layer_cross_k, n_layer_cross_v, self.pe[offset], self_attn_mask, cross_attn_mask]
 
368
  # np.save(os.path.join(file_path, f"{i}.npy"), npy)
369
 
370
  if i == 0:
371
+ start_time = time.time()
372
  logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_main_one_token(
373
  to_numpy(tokens),
374
  to_numpy(n_layer_self_k_cache),
 
379
  self_attn_mask,
380
  to_numpy(cross_attn_mask)
381
  )
382
+ print(f"run decoder_main take {(time.time() - start_time) * 1000}ms")
383
  else:
384
+ start_time = time.time()
385
  logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
386
  to_numpy(tokens),
387
  to_numpy(n_layer_self_k_cache),
 
392
  self_attn_mask,
393
  to_numpy(cross_attn_mask)
394
  )
395
+ print(f"run decoder_loop take {(time.time() - start_time) * 1000}ms")
396
 
397
  offset += 1
398
  logits = torch.from_numpy(logits)
 
520
  to_numpy(feats),
521
  to_numpy(lengths)
522
  )
523
+ print(f"run encoder take {(time.time() - start_time) * 1000}ms")
524
  nbest_hyps = self.run_decoder(n_layer_cross_k,
525
  n_layer_cross_v,
526
  cross_attn_mask,