```
ForgettingTransformerForCausalLM(
  (model): ForgettingTransformerModel(
    (embeddings): Embedding(50277, 256)
    (layers): ModuleList(
      (0-2): 3 x ForgettingTransformerBlock(
        (attn_norm): RMSNorm(256, eps=1e-06)
        (attn): ForgettingAttentionLayer(
          (q_proj): Linear(in_features=256, out_features=256, bias=False)
          (k_proj): Linear(in_features=256, out_features=256, bias=False)
          (v_proj): Linear(in_features=256, out_features=256, bias=False)
          (o_proj): Linear(in_features=256, out_features=256, bias=False)
          (fgate_proj): Linear(in_features=256, out_features=4, bias=True)
        )
        (mlp_norm): RMSNorm(256, eps=1e-06)
        (mlp): ForgettingTransformerMLP(
          (gate_proj): Linear(in_features=256, out_features=1536, bias=False)
          (down_proj): Linear(in_features=768, out_features=256, bias=False)
          (act_fn): SiLU()
        )
      )
    )
    (norm): RMSNorm(256, eps=1e-06)
  )
  (lm_head): Linear(in_features=256, out_features=50277, bias=False)
)
```
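
To make the printed shapes concrete, here is a minimal sketch of what a layer with these projections could look like in plain PyTorch. This is an illustration only, not the library's actual `ForgettingAttentionLayer`: the class name `ForgettingAttentionSketch`, its constructor arguments, and the naive quadratic attention loop below are assumptions chosen to match the printout (hidden size 256, 4 heads, one forget-gate logit per head from `fgate_proj`).

```python
# Minimal sketch (assumed implementation, not the library's code) of the
# projection shapes shown in the model printout above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ForgettingAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int = 256, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Query/key/value/output projections, matching 256 -> 256 above.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # One forget-gate logit per head per token, matching 256 -> 4 above.
        self.fgate_proj = nn.Linear(hidden_size, num_heads, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim)
        # Per-head forget gates in (0, 1); their cumulative log-sum biases the
        # attention logits so older tokens are progressively down-weighted.
        log_f = F.logsigmoid(self.fgate_proj(x))          # (b, t, h)
        decay = torch.cumsum(log_f, dim=1)                 # (b, t, h)
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / self.head_dim ** 0.5
        # bias[i, j] = sum_{l=j+1..i} log f_l for query i, key j.
        bias = decay.transpose(1, 2).unsqueeze(-1) - decay.transpose(1, 2).unsqueeze(-2)
        causal = torch.ones(t, t, dtype=torch.bool, device=x.device).tril()
        scores = (scores + bias).masked_fill(~causal, float("-inf"))
        out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)
        return self.o_proj(out.reshape(b, t, -1))

# Quick shape check against the printout: 256-dim hidden states, 4 heads.
layer = ForgettingAttentionSketch()
print(layer(torch.randn(2, 16, 256)).shape)  # torch.Size([2, 16, 256])
```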