Update modeling_mmMamba.py
Browse files- modeling_mmMamba.py +5 -7
modeling_mmMamba.py
CHANGED
|
@@ -24,22 +24,20 @@ import torch.nn.functional as F
|
|
| 24 |
import torch.utils.checkpoint
|
| 25 |
from einops import rearrange
|
| 26 |
from torch import nn
|
| 27 |
-
from torch.nn import
|
| 28 |
from transformers.activations import ACT2FN
|
| 29 |
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
| 30 |
-
CausalLMOutputWithPast
|
| 31 |
-
SequenceClassifierOutputWithPast)
|
| 32 |
from transformers.modeling_utils import PreTrainedModel
|
| 33 |
from transformers.utils import (add_start_docstrings,
|
| 34 |
add_start_docstrings_to_model_forward, logging,
|
| 35 |
replace_return_docstrings)
|
| 36 |
-
from
|
| 37 |
-
|
| 38 |
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
| 39 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 40 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 41 |
-
|
| 42 |
-
import time
|
| 43 |
|
| 44 |
try:
|
| 45 |
from transformers.generation.streamers import BaseStreamer
|
|
|
|
| 24 |
import torch.utils.checkpoint
|
| 25 |
from einops import rearrange
|
| 26 |
from torch import nn
|
| 27 |
+
from torch.nn import CrossEntropyLoss
|
| 28 |
from transformers.activations import ACT2FN
|
| 29 |
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
| 30 |
+
CausalLMOutputWithPast)
|
|
|
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from transformers.utils import (add_start_docstrings,
|
| 33 |
add_start_docstrings_to_model_forward, logging,
|
| 34 |
replace_return_docstrings)
|
| 35 |
+
from fused_norm_gate import FusedRMSNormSwishGate
|
| 36 |
+
|
| 37 |
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
| 38 |
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 39 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 40 |
+
|
|
|
|
| 41 |
|
| 42 |
try:
|
| 43 |
from transformers.generation.streamers import BaseStreamer
|