File size: 1,806 Bytes
fe0450e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: utf-8 -*-

# from fla.layers import (ABCAttention, Attention, BasedLinearAttention,
#                         DeltaNet, GatedLinearAttention, HGRN2Attention,
#                         LinearAttention, MultiScaleRetention,
#                         ReBasedLinearAttention)
# from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM,
#                         DeltaNetModel, GLAForCausalLM, GLAModel,
#                         HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM,
#                         HGRNModel, LinearAttentionForCausalLM,
#                         LinearAttentionModel, RetNetForCausalLM, RetNetModel,
#                         RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM,
#                         TransformerModel)
# from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based,
#                      fused_chunk_gla, fused_chunk_retention)
from .models import emla,emgla,mask_deltanet,mask_gdn,transformer

# __all__ = [
#     'ABCAttention',
#     'Attention',
#     'BasedLinearAttention',
#     'DeltaNet',
#     'HGRN2Attention',
#     'GatedLinearAttention',
#     'LinearAttention',
#     'MultiScaleRetention',
#     'ReBasedLinearAttention',
#     'ABCForCausalLM',
#     'ABCModel',
#     'DeltaNetForCausalLM',
#     'DeltaNetModel',
#     'HGRNForCausalLM',
#     'HGRNModel',
#     'HGRN2ForCausalLM',
#     'HGRN2Model',
#     'GLAForCausalLM',
#     'GLAModel',
#     'LinearAttentionForCausalLM',
#     'LinearAttentionModel',
#     'RetNetForCausalLM',
#     'RetNetModel',
#     'RWKV6ForCausalLM',
#     'RWKV6Model',
#     'TransformerForCausalLM',
#     'TransformerModel',
#     'chunk_gla',
#     'chunk_retention',
#     'fused_chunk_based',
#     'fused_chunk_gla',
#     'fused_chunk_retention'
# ]

# Version identifier reported by this package.
__version__ = "0.1"