hyx21 committed
Commit eaade69 · verified · Parent: bb6b7d4

Update modeling_llama_long_infllmv2.py

Files changed (1):
  modeling_llama_long_infllmv2.py  +0 -3
modeling_llama_long_infllmv2.py CHANGED
@@ -50,10 +50,8 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.models.llama.configuration_llama import LlamaConfig
-from moba import moba_attn_varlen
 from functools import lru_cache
 from .cis_pooling import nosa_mean_pooling
-from native_sparse_attention.ops.triton.topk_sparse_attention import topk_sparse_attention
 from tqdm import tqdm
 import torch.cuda.nvtx as nvtx
 
@@ -683,7 +681,6 @@ class LlamaFlashAttention2(LlamaAttention):
 
         return attn_output, attn_weights, past_key_value
 
-from native_sparse_attention.ops.triton.topk_sparse_attention import topk_sparse_attention
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
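The commit drops two module-level imports: moba.moba_attn_varlen and native_sparse_attention's topk_sparse_attention (the latter appeared twice, once in the import header and once again immediately before the flash_attn block). Both were hard dependencies, so loading the model failed at import time when either package was missing. An alternative to deleting such imports outright is the try/except guard this file already uses for flash_attn. A minimal sketch of that pattern, assuming the same module paths as in the diff; the HAS_MOBA / HAS_NSA flags and the deferred-error wrapper are illustrative assumptions, not code from this repository:

# Optional-import guard, mirroring the try/except used for flash_attn.
# HAS_MOBA / HAS_NSA and the wrapper below are illustrative assumptions.
try:
    from moba import moba_attn_varlen
    HAS_MOBA = True
except ImportError:
    moba_attn_varlen = None
    HAS_MOBA = False

try:
    from native_sparse_attention.ops.triton.topk_sparse_attention import (
        topk_sparse_attention,
    )
    HAS_NSA = True
except ImportError:
    topk_sparse_attention = None
    HAS_NSA = False


def topk_sparse_attention_or_raise(*args, **kwargs):
    # Fail only when the optional kernel is actually called,
    # not when the model file is first imported.
    if not HAS_NSA:
        raise ImportError(
            "native_sparse_attention is required for topk_sparse_attention"
        )
    return topk_sparse_attention(*args, **kwargs)

Since the change is +0 -3 with no compensating code, outright removal is the simpler fix here; the guard pattern is only worth it when the optional kernel is still reachable from some code path.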