Upload folder using huggingface_hub
Browse files- __init__.py +9 -0
- __pycache__/__init__.cpython-312.pyc +0 -0
- __pycache__/modeling_my_model.cpython-312.pyc +0 -0
- config.json +28 -0
- configuration_my_model.py +46 -0
- merges.txt +0 -0
- modeling_my_model.py +908 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +5 -0
- tokenizer.json +0 -0
- tokenizer_config.json +20 -0
- vocab.json +0 -0
__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Package init: registers the custom GPT model with the Hugging Face Auto classes.

Importing this package lets `AutoConfig.from_pretrained(...)` and
`AutoModelForCausalLM.from_pretrained(...)` resolve the "custom_gpt"
model type to the local GPTConfig / GPT classes.
"""

# from .modeling_my_model import GPT, GPTConfig

from transformers import AutoConfig, AutoModelForCausalLM

from .configuration_my_model import GPTConfig
from .modeling_my_model import GPT

# Map the "custom_gpt" model_type to the local config class, then map that
# config class to the model class so the Auto* factories can build GPT.
AutoConfig.register("custom_gpt", GPTConfig)
AutoModelForCausalLM.register(GPTConfig, GPT)
__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (253 Bytes). View file
|
|
|
__pycache__/modeling_my_model.cpython-312.pyc
ADDED
|
Binary file (34.3 kB). View file
|
|
|
config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "custom_gpt",
|
| 3 |
+
"architectures": ["GPT"],
|
| 4 |
+
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_my_model.GPTConfig",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_my_model.GPT"
|
| 8 |
+
},
|
| 9 |
+
|
| 10 |
+
"block_size": 1024,
|
| 11 |
+
"vocab_size": 50304,
|
| 12 |
+
"n_layer": 24,
|
| 13 |
+
"n_head": 16,
|
| 14 |
+
"n_embd": 1024,
|
| 15 |
+
"dropout": 0.0,
|
| 16 |
+
"bias": false,
|
| 17 |
+
|
| 18 |
+
"hc_num_streams": 1,
|
| 19 |
+
"hc_num_fracs": 1,
|
| 20 |
+
"hc_disable": true,
|
| 21 |
+
"mhc": false,
|
| 22 |
+
"sinkhorn_iters": 10,
|
| 23 |
+
"sinkhorn_tau": 0.05,
|
| 24 |
+
"mhc_h_res_proj": "sinkhorn",
|
| 25 |
+
"ns_steps": 5,
|
| 26 |
+
"ns_eps": 1e-7,
|
| 27 |
+
"ns_coeffs": [3.0, -3.2, 1.2]
|
| 28 |
+
}
|
configuration_my_model.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class GPTConfig(PretrainedConfig):
    """Hugging Face configuration for the custom "custom_gpt" model.

    Holds the transformer geometry (context length, vocab, depth, heads,
    width) together with the hyper-connection / frac-connection and mhc
    (sinkhorn / orthostochastic mixing) knobs read by the modeling code.
    """

    model_type = "custom_gpt"

    def __init__(
        self,
        block_size=1024,
        vocab_size=50304,
        n_layer=12,
        n_head=12,
        n_embd=768,
        dropout=0.0,
        bias=True,
        hc_num_streams=1,
        hc_num_fracs=1,
        hc_disable=False,
        mhc=False,
        sinkhorn_iters=10,
        sinkhorn_tau=0.05,
        mhc_h_res_proj="sinkhorn",
        ns_steps=5,
        ns_eps=1e-7,
        ns_coeffs=(3.0, -3.2, 1.2),
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Copy every hyperparameter onto the config verbatim; grouping them in
        # a single mapping keeps the attribute list in one place.
        hyperparams = dict(
            # core transformer geometry
            block_size=block_size,
            vocab_size=vocab_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            dropout=dropout,
            bias=bias,
            # hyper-connection / frac-connection options
            hc_num_streams=hc_num_streams,
            hc_num_fracs=hc_num_fracs,
            hc_disable=hc_disable,
            # mhc mixing options
            mhc=mhc,
            sinkhorn_iters=sinkhorn_iters,
            sinkhorn_tau=sinkhorn_tau,
            mhc_h_res_proj=mhc_h_res_proj,
            ns_steps=ns_steps,
            ns_eps=ns_eps,
            ns_coeffs=ns_coeffs,
        )
        for name, value in hyperparams.items():
            setattr(self, name, value)
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_my_model.py
ADDED
|
@@ -0,0 +1,908 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 9 |
+
from transformers.modeling_outputs import CausalLMOutput
|
| 10 |
+
from typing import Callable
|
| 11 |
+
from transformers.generation.utils import GenerationMixin
|
| 12 |
+
from functools import partial
|
| 13 |
+
from random import randrange
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn, cat
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch.nn import Module, Sequential
|
| 20 |
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
| 21 |
+
|
| 22 |
+
from einops import rearrange, repeat, reduce, einsum
|
| 23 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 24 |
+
from .configuration_my_model import GPTConfig
|
| 25 |
+
|
| 26 |
+
"""
|
| 27 |
+
ein notation:
|
| 28 |
+
b - batch
|
| 29 |
+
d - feature dimension
|
| 30 |
+
s - residual streams
|
| 31 |
+
t - residual streams + num branch inputs
|
| 32 |
+
f - number of fractions (division of feature dimension space)
|
| 33 |
+
v - number of views for branch input
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
# helper functions
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def exists(v):
    """Return True when *v* is not None (falsy values like 0 still count)."""
    return v is not None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def divisible_by(num, den):
    """Return True when *num* divides evenly by *den*."""
    return num % den == 0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def default(v, d):
    """Return *v* unless it is None, in which case return the fallback *d*."""
    return d if v is None else v
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def identity(t):
    """No-op transform: return *t* unchanged."""
    return t
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def add(x, y):
    """Return x + y; used as the default depth-residual combiner."""
    return x + y
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def sinkhorn_log(logits, num_iters=10, tau=0.05):
    """Project a square `logits` matrix toward doubly stochastic form.

    Runs Sinkhorn iterations in log space at temperature `tau`, targeting
    uniform row/column marginals of 1/n, then rescales by n so each row and
    column of the returned matrix sums to approximately 1.
    """
    size = logits.shape[-1]
    scaled = logits / tau

    # target log-marginal: each row/column should carry mass 1/n pre-rescale
    target = torch.full(
        (size,), -math.log(size), device=logits.device, dtype=logits.dtype
    )

    row_pot = torch.zeros(size, device=scaled.device, dtype=scaled.dtype)
    col_pot = torch.zeros(size, device=scaled.device, dtype=scaled.dtype)

    # alternate row / column normalization in log space
    for _ in range(num_iters):
        row_pot = target - torch.logsumexp(scaled + col_pot.unsqueeze(0), dim=1)
        col_pot = target - torch.logsumexp(scaled + row_pot.unsqueeze(1), dim=0)

    return (scaled + row_pot.unsqueeze(1) + col_pot.unsqueeze(0)).exp() * size
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def zeropower_via_newtonschulz(X, steps=5, eps=1e-7, coeffs=(3.0, -3.2, 1.2)):
    """Approximately orthogonalize matrix `X` via a Newton-Schulz iteration.

    Normalizes X by its (Frobenius) norm, then repeatedly applies the
    polynomial update a*X + (b*A + c*A@A) @ X with A = X @ X^T, which pushes
    the singular values toward 1. The iteration runs on the wide orientation;
    tall inputs are transposed in and back out.
    """
    a, b, c = coeffs

    # keep the spectrum inside the iteration's region of convergence
    X = X / (X.norm() + eps)

    flipped = X.shape[0] > X.shape[1]
    if flipped:
        X = X.T

    for _ in range(steps):
        gram = X @ X.T
        poly = b * gram + c * gram @ gram
        X = a * X + poly @ X

    return X.T if flipped else X
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def orthostochastic_project(
    logits, ns_steps=5, ns_eps=1e-7, ns_coeffs=(3.0, -3.2, 1.2)
):
    """Map `logits` to an (approximately) orthostochastic matrix.

    Pushes the matrix toward orthogonality with Newton-Schulz, then squares
    entrywise: the elementwise square of an orthogonal matrix has rows and
    columns that each sum to 1.
    """
    ortho = zeropower_via_newtonschulz(
        logits, steps=ns_steps, eps=ns_eps, coeffs=ns_coeffs
    )
    return ortho * ortho
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# main functions
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_expand_reduce_stream_functions(
    num_streams, add_stream_embed=False, dim=None, disable=False
):
    """Build the (expand, reduce) module pair framing a hyper-connected stack.

    The expand module replicates the batch across `num_streams` residual
    streams (optionally adding learned per-stream embeddings); the reduce
    module sums the streams back into a single batch. With a single stream,
    or when disabled, both are identity modules.
    """
    if num_streams == 1 or disable:
        return (nn.Identity(), nn.Identity())

    if not add_stream_embed:
        expand_fn = Reduce(
            pattern="b ... -> (b s) ...", reduction="repeat", s=num_streams
        )
    else:
        assert exists(dim), (
            "`dim` must be passed into get_init_and_expand_reduce_stream_functions for returning an expansion function with stream embeddings added"
        )
        expand_fn = StreamEmbed(num_streams, dim, expand_to_streams=True)

    reduce_fn = Reduce(pattern="(b s) ... -> b ...", reduction="sum", s=num_streams)

    return expand_fn, reduce_fn
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_init_and_expand_reduce_stream_functions(
    num_streams, num_fracs=1, dim=None, add_stream_embed=False, disable=None
):
    """Return (layer factory, expand module, reduce module) for residual streams.

    The factory constructs a HyperConnections wrapper per layer — or a plain
    Residual when hyper/frac connections are effectively disabled — while the
    expand/reduce modules convert between a single batch and the
    stream-replicated batch at the boundaries of the network.
    """
    if disable is None:
        disable = num_streams == 1 and num_fracs == 1

    conn_klass = Residual if disable else HyperConnections

    make_conn = partial(conn_klass, num_streams, num_fracs=num_fracs)
    if dim is not None:
        make_conn = partial(make_conn, dim=dim)

    expand_fn, reduce_fn = get_expand_reduce_stream_functions(
        num_streams, add_stream_embed=add_stream_embed, dim=dim, disable=disable
    )

    return (make_conn, expand_fn, reduce_fn)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# norms
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class RMSNorm(Module):
    """RMS normalization with a learned gain, initialized to a no-op gain.

    Normalizes the feature dimension to unit L2 norm, rescales by sqrt(dim)
    (unit RMS at init), and applies a per-feature gain of (gamma + 1) where
    gamma starts at zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        return unit * self.scale * (self.gamma + 1)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# main classes
|
| 161 |
+
|
| 162 |
+
# residual base class
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class Residual(Module):
    """Plain residual wrapper sharing the HyperConnections interface.

    `forward` either runs the wrapped branch and adds the (optionally
    transformed) residual back, or — when no branch was given at init —
    returns the branch input together with a closure that performs the
    residual addition later.
    """

    def __init__(
        self,
        *args,
        branch: Module | None = None,
        residual_transform: Module | None = None,
        **kwargs,
    ):
        super().__init__()
        self.branch = branch
        self.residual_transform = default(residual_transform, nn.Identity())

    def width_connection(self, residuals):
        # no stream mixing for a plain residual: branch input is the residuals
        return residuals, residuals, dict()

    def depth_connection(
        self,
        branch_output,
        residuals,
    ):
        # classic skip connection; residuals optionally reprojected first
        return branch_output + self.residual_transform(residuals)

    def decorate_branch(self, branch: Callable):
        """Wrap a callable so that calling it also applies the residual add."""
        assert not exists(self.branch), "branch was already wrapped on init"

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)
            return add_residual(branch(branch_input, *args, **kwargs))

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):
        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            # the branch may return extra values; only the first leaf gets the
            # residual added, the rest pass through untouched
            (main_out, *extras), spec = tree_flatten(branch_out)
            main_out = self.depth_connection(main_out, residuals, **residual_kwargs)
            return tree_unflatten((main_out, *extras), spec)

        if not exists(self.branch):
            return branch_input, add_residual_fn

        return add_residual_fn(
            self.branch(branch_input, *branch_args, **branch_kwargs)
        )
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# hyper connection residual streams
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class HyperConnections(Module):
    def __init__(
        self,
        num_residual_streams,
        *,
        dim,
        branch: Module | None = None,
        layer_index=None,
        tanh=True,
        channel_first=False,
        dropout=0.0,
        residual_transform: Module
        | None = None,  # to support resnet blocks where dimension in is not equal to dimension out - usually a residual conv
        add_branch_out_to_residual=True,  # will disable depth connections (weighted residual sum with beta) if set False
        num_input_views=1,  # allow for the branch module to receive multiple input views, dimension placed on the very left (before batch)
        depth_residual_fn=add,
        num_fracs=1,  # https://arxiv.org/abs/2503.14125
        mhc=False,
        sinkhorn_iters=10,
        sinkhorn_tau=0.05,
        mhc_h_res_proj="sinkhorn",
        ns_steps=5,
        ns_eps=1e-7,
        ns_coeffs=(3.0, -3.2, 1.2),
    ):
        """Maintain multiple residual streams with learned width/depth mixing.

        Appendix J, Algorithm2 in - https://arxiv.org/abs/2409.19606

        The wrapper mixes the streams into a branch input (width connection),
        runs the branch, and scatters its output back into the streams (depth
        connection). With ``mhc=True`` the mixing matrices are constrained to
        (approximately) doubly stochastic / orthostochastic form instead.
        """
        super().__init__()

        self.branch = branch

        self.act = nn.Tanh() if tanh else nn.Identity()

        # frac-connections paper - num_fracs > 1 will be the `m` in their paper https://arxiv.org/abs/2503.14125

        assert num_fracs >= 1

        self.num_fracs = num_fracs
        self.has_fracs = num_fracs > 1

        self.split_fracs = Rearrange("b ... (f d) -> b ... f d", f=num_fracs)
        self.merge_fracs = Rearrange("b ... f d -> b ... (f d)")

        assert divisible_by(dim, num_fracs), (
            f"feature dimension ({dim}) must be divisible by the `num_fracs` ({num_fracs})"
        )

        dim //= num_fracs  # effective dim handled in dimension is feature dimension divided by num fractions

        # they used layernorm in paper, but rmsnorm is fine given what we know now

        self.norm = RMSNorm(dim)

        assert num_residual_streams > 0, "`num_residual_streams` must be greater than 0"

        self.num_residual_streams = num_residual_streams
        init_residual_index = (
            default(layer_index, randrange(num_residual_streams)) % num_residual_streams
        )  # just choose one random residual stream if layer index not given

        # handle the parameter dimensions, which may require (num_residuals x num_fractions) - generalizing hyper + frac connections

        num_residual_streams_fracs = num_residual_streams * num_fracs
        num_input_views_fracs = num_input_views * num_fracs

        # width num residual streams

        assert num_input_views >= 1
        self.num_input_views = num_input_views

        # width connection

        # static alpha starts as: route stream `init_residual_index` into the
        # branch, and pass every stream through to itself (identity block)
        init_alpha0 = torch.zeros((num_residual_streams_fracs, num_input_views_fracs))
        init_alpha0[init_residual_index, :] = 1.0

        self.static_alpha = nn.Parameter(
            cat((init_alpha0, torch.eye(num_residual_streams_fracs)), dim=1)
        )

        # data-dependent correction to alpha, scaled near zero at init
        self.dynamic_alpha_fn = nn.Parameter(
            torch.zeros(dim, num_residual_streams_fracs + num_input_views_fracs)
        )
        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # depth connection related (beta)

        self.add_branch_out_to_residual = add_branch_out_to_residual

        if add_branch_out_to_residual:
            self.static_beta = nn.Parameter(torch.ones(num_residual_streams_fracs))

            dynamic_beta_shape = (
                (dim,) if num_fracs == 1 else (dim, num_fracs)
            )  # preserve backwards compat
            self.dynamic_beta_fn = nn.Parameter(torch.zeros(dynamic_beta_shape))

            self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

        # dropouts

        self.dropout = nn.Dropout(dropout)

        # channel first option

        self.channel_first = channel_first

        # maybe residual transform

        self.residual_transform = default(residual_transform, nn.Identity())

        # maybe custom depth connection residual function
        # this is to prepare for gating the addition of the branch outputs to the residual streams
        # needed for memory lanes a la RMT / LMM

        self.depth_residual_fn = depth_residual_fn

        # mhc (constrained mixing) hyperparameters
        self.mhc = mhc
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.mhc_h_res_proj = mhc_h_res_proj
        self.ns_steps = ns_steps
        self.ns_eps = ns_eps
        self.ns_coeffs = ns_coeffs

        if mhc:
            assert num_fracs == 1, "mhc currently requires num_fracs = 1"
            assert num_input_views == 1, "mhc currently requires num_input_views = 1"
            assert mhc_h_res_proj in (
                "sinkhorn",
                "orthostochastic",
            ), "mhc_h_res_proj must be 'sinkhorn' or 'orthostochastic'"

            # logits initialized strongly toward the identity permutation, so
            # the projected stream-mixing matrix starts near identity
            H_res_init = torch.full((num_residual_streams, num_residual_streams), -8.0)
            H_res_init.fill_diagonal_(0.0)
            self.H_res_logits = nn.Parameter(H_res_init)

            # softmax over these starts near one-hot at init_residual_index
            H_pre_init = torch.full((num_residual_streams,), -8.0)
            H_pre_init[init_residual_index] = 0.0
            self.H_pre_logits = nn.Parameter(H_pre_init)

            if add_branch_out_to_residual:
                self.H_post_logits = nn.Parameter(torch.zeros(num_residual_streams))

    def width_connection(self, residuals):
        # mixes the residual streams into the branch input; returns
        # (branch_input, residuals-for-depth-connection, kwargs for depth)
        streams = self.num_residual_streams

        maybe_transformed_residuals = self.residual_transform(residuals)

        # width connection

        # handle channel first

        if self.channel_first:
            residuals = rearrange(residuals, "b d ... -> b ... d")

        # split out fractions

        residuals = self.split_fracs(residuals)

        # split out streams

        residuals = rearrange(residuals, "(b s) ... d -> b ... s d", s=streams)

        if self.mhc:
            # mhc path: mix streams with a (near) doubly stochastic H_res and
            # select the branch input with a softmax-weighted H_pre
            residuals_mixed_source = maybe_transformed_residuals

            if self.channel_first:
                residuals_mixed_source = rearrange(
                    residuals_mixed_source, "b d ... -> b ... d"
                )

            residuals_mixed_source = self.split_fracs(residuals_mixed_source)
            residuals_mixed_source = rearrange(
                residuals_mixed_source, "(b s) ... d -> b ... s d", s=streams
            )

            if self.mhc_h_res_proj == "orthostochastic":
                H_res = orthostochastic_project(
                    self.H_res_logits,
                    ns_steps=self.ns_steps,
                    ns_eps=self.ns_eps,
                    ns_coeffs=self.ns_coeffs,
                )
            else:
                H_res = sinkhorn_log(
                    self.H_res_logits, self.sinkhorn_iters, self.sinkhorn_tau
                )
            H_pre = F.softmax(self.H_pre_logits, dim=-1)

            H_post = None
            if self.add_branch_out_to_residual:
                H_post = F.softmax(self.H_post_logits, dim=-1)

            residuals_mixed = einsum(
                H_res, residuals_mixed_source, "s t, ... s d -> ... t d"
            )
            branch_input = einsum(H_pre, residuals, "s, ... s d -> ... d")

            # optional debugging/monitoring statistics (opt-in attribute)
            if getattr(self, "collect_stats", False):
                with torch.no_grad():
                    stats = dict(
                        h_res_min=H_res.min(),
                        h_res_row_sum=H_res.sum(dim=-1).mean(),
                        h_res_col_sum=H_res.sum(dim=-2).mean(),
                        h_pre_min=H_pre.min(),
                    )
                    if H_post is not None:
                        stats["h_post_min"] = H_post.min()
                    self.last_stats = {k: v.detach() for k, v in stats.items()}

            if self.channel_first:
                branch_input = rearrange(branch_input, "b ... d -> b d ...")

            branch_input = self.merge_fracs(branch_input)

            return (
                branch_input,
                maybe_transformed_residuals,
                dict(beta=H_post, residuals_mixed=residuals_mixed),
            )

        # norm

        normed = self.norm(residuals)

        # alpha for weighted sum of residuals going into branch

        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
        dynamic_alpha = wc_weight * self.dynamic_alpha_scale

        static_alpha = rearrange(self.static_alpha, "(f s) d -> f s d", s=streams)

        alpha = dynamic_alpha + static_alpha

        alpha = self.split_fracs(
            alpha
        )  # (batch, seq, fracs1, streams, fracs2, input + residual streams)

        # beta for weights from branch output back to residual streams

        beta = None

        if self.add_branch_out_to_residual:
            dc_weight = self.act(normed @ self.dynamic_beta_fn)

            if not self.has_fracs:
                dc_weight = rearrange(dc_weight, "... -> ... 1")

            dynamic_beta = dc_weight * self.dynamic_beta_scale

            static_beta = rearrange(self.static_beta, "... (s f) -> ... s f", s=streams)

            beta = dynamic_beta + static_beta

        # optional debugging/monitoring statistics (opt-in attribute)
        if getattr(self, "collect_stats", False):
            with torch.no_grad():
                num_input_views_fracs = self.num_input_views * self.num_fracs
                alpha_branch = alpha[..., :num_input_views_fracs]
                alpha_residual = alpha[..., num_input_views_fracs:]
                alpha_branch_abs_mean = alpha_branch.abs().mean()
                alpha_residual_abs_mean = alpha_residual.abs().mean()
                stats = dict(
                    alpha_branch_mean=alpha_branch.mean(),
                    alpha_branch_abs_mean=alpha_branch_abs_mean,
                    alpha_residual_mean=alpha_residual.mean(),
                    alpha_residual_abs_mean=alpha_residual_abs_mean,
                    alpha_branch_residual_ratio=alpha_branch_abs_mean
                    / (alpha_residual_abs_mean + 1e-8),
                )
                if beta is not None:
                    stats.update(
                        beta_mean=beta.mean(),
                        beta_abs_mean=beta.abs().mean(),
                        beta_min=beta.min(),
                        beta_max=beta.max(),
                    )
                self.last_stats = {k: v.detach() for k, v in stats.items()}

        # weighted combination of streams: first num_input_views slots feed the
        # branch, the remaining slots become the new residual streams
        mix_h = einsum(alpha, residuals, "... f1 s f2 t, ... f1 s d -> ... f2 t d")

        if self.num_input_views == 1:
            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
        else:
            branch_input, residuals = (
                mix_h[..., : self.num_input_views, :],
                mix_h[..., self.num_input_views :, :],
            )
            branch_input = rearrange(branch_input, "b ... v d -> v b ... d")

        if self.channel_first:
            branch_input = rearrange(branch_input, "b ... d -> b d ...")

        # maybe merge fractions back

        branch_input = self.merge_fracs(branch_input)

        return branch_input, maybe_transformed_residuals, dict(beta=beta)

    def depth_connection(self, branch_output, residuals, *, beta, residuals_mixed=None):
        # scatters the branch output back across the residual streams,
        # weighted by beta, and re-flattens streams into the batch dimension
        assert self.add_branch_out_to_residual

        # maybe split fractions

        branch_output = self.split_fracs(branch_output)

        # 'depth' connection

        if self.channel_first:
            branch_output = rearrange(branch_output, "b d ... -> b ... d")

        if self.mhc:
            # mhc path: residual streams were already mixed by H_res in
            # width_connection; just add the H_post-weighted branch output
            assert residuals_mixed is not None
            assert beta is not None

            branch_to_streams = einsum(branch_output, beta, "b ... d, s -> b ... s d")
            output = residuals_mixed + branch_to_streams
            output = rearrange(output, "b ... s d -> (b s) ... d")

            output = self.merge_fracs(output)

            if self.channel_first:
                output = rearrange(output, "b ... d -> b d ...")

            return self.dropout(output)

        output = einsum(
            branch_output, beta, "b ... f1 d, b ... f1 s f2 -> b ... f2 s d"
        )

        output = rearrange(output, "b ... s d -> (b s) ... d")

        # merge back fractions

        output = self.merge_fracs(output)

        # channel first

        if self.channel_first:
            output = rearrange(output, "b ... d -> b d ...")

        residuals = self.depth_residual_fn(output, residuals)

        return self.dropout(residuals)

    def decorate_branch(self, branch: Callable):
        # wrap a callable so calling it also performs width + depth connections
        assert not exists(self.branch), "branch was already wrapped on init"

        def forward_and_add_residual(residual, *args, **kwargs):
            branch_input, add_residual = self.forward(residual)

            branch_output = branch(branch_input, *args, **kwargs)

            residual = add_residual(branch_output)

            return residual

        return forward_and_add_residual

    def forward(self, residuals, *branch_args, **branch_kwargs):
        branch_input, residuals, residual_kwargs = self.width_connection(residuals)

        def add_residual_fn(branch_out):
            if not self.add_branch_out_to_residual:
                return branch_out

            # only the first leaf of the branch output gets folded back into
            # the residual streams; any extra outputs pass through untouched
            (branch_out, *rest), tree_spec = tree_flatten(branch_out)

            branch_out = self.depth_connection(branch_out, residuals, **residual_kwargs)

            return tree_unflatten((branch_out, *rest), tree_spec)

        # without a wrapped branch, hand back the input plus a closure that
        # the caller invokes on the branch output later
        if not exists(self.branch):
            return branch_input, add_residual_fn

        branch_output = self.branch(branch_input, *branch_args, **branch_kwargs)

        return add_residual_fn(branch_output)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
HyperConnections.get_expand_reduce_stream_functions = staticmethod(
|
| 603 |
+
get_expand_reduce_stream_functions
|
| 604 |
+
)
|
| 605 |
+
HyperConnections.get_init_and_expand_reduce_stream_functions = staticmethod(
|
| 606 |
+
get_init_and_expand_reduce_stream_functions
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
# stream embed
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class StreamEmbed(Module):
    """Add a learned per-stream embedding to stream-expanded residuals.

    Optionally replicates a plain batch into `num_streams` copies first
    (`expand_to_streams=True`). Works for both channel-first and
    channel-last layouts by unfolding the stream axis out of the batch,
    adding the embedding, and folding it back.
    """

    def __init__(self, num_streams, dim, channel_first=False, expand_to_streams=False):
        super().__init__()
        self.channel_first = channel_first
        self.num_streams = num_streams

        self.expand_to_streams = expand_to_streams
        # Zero-init so the module starts out as an identity on the residuals.
        self.stream_embed = nn.Parameter(torch.zeros(num_streams, dim))

    def forward(self, residuals):
        s = self.num_streams

        if self.expand_to_streams:
            residuals = repeat(residuals, "b ... -> (b s) ...", s=s)

        # Pick the unfold/fold patterns once, based on layout.
        if self.channel_first:
            unfold, fold = "(b s) d ... -> b ... s d", "b ... s d -> (b s) d ..."
        else:
            unfold, fold = "(b s) ... d -> b ... s d", "b ... s d -> (b s) ... d"

        residuals = rearrange(residuals, unfold, s=s)
        residuals = residuals + self.stream_embed
        return rearrange(residuals, fold, s=s)
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
# attention pool - taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
class AttentionPoolReduceStream(Module):
    """Collapse the stream axis with a learned attention pool.

    Attention-pool reduction as used in Enformer
    (https://www.nature.com/articles/s41592-021-01252-x). The logit
    projection is initialised to the identity, so pooling starts out close
    to a plain softmax over the streams themselves.
    """

    def __init__(self, num_streams, dim, channel_first=False):
        super().__init__()
        self.num_streams = num_streams
        self.channel_first = channel_first

        self.to_attn_logits = nn.Linear(dim, dim, bias=False)
        self.to_attn_logits.weight.data.copy_(torch.eye(dim))

    def forward(self, residuals):
        s = self.num_streams

        # Unfold the stream axis out of the (batch * streams) dimension.
        if self.channel_first:
            pattern = "(b s) d ... -> b ... s d"
        else:
            pattern = "(b s) ... d -> b ... s d"
        residuals = rearrange(residuals, pattern, s=s)

        # Softmax over the stream axis, then a weighted sum across streams.
        attn = self.to_attn_logits(residuals).softmax(dim=-2)
        residuals = reduce(residuals * attn, "b ... s d -> b ... d", "sum")

        if self.channel_first:
            residuals = rearrange(residuals, "b ... d -> b d ...")

        return residuals
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention (GPT-style, reference implementation).

    Projects the input to Q, K, V with one fused linear layer, masks out
    future positions with a lower-triangular buffer, and projects the
    concatenated heads back to `n_embd`.

    NOTE(review): `config.bias` and `config.dropout` are not consulted here —
    both linear layers keep nn.Linear's default bias=True and no dropout is
    applied. Presumably intentional (the shipped checkpoint was trained this
    way), but worth confirming against the training setup.
    """

    def __init__(self, config):
        super().__init__()

        # Fail fast on an invalid head configuration; otherwise forward()
        # dies later with an opaque .view() error.
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )

        # Fused QKV projection followed by the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Marker attribute, presumably consumed by a custom weight init — confirm.
        self.c_proj.NANOGPT_SCALE_INIT = 1

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Causal mask; persistent=False keeps it out of the state dict.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
            persistent=False,
        )

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding dim

        # Split the fused projection into per-head Q, K, V: (B, n_head, T, head_dim).
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Scaled dot-product attention with the causal mask applied pre-softmax.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)

        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-merge the heads
        return self.c_proj(y)
|
| 714 |
+
class MLP(nn.Module):
    """Position-wise feed-forward block: 4x linear expansion, tanh-GELU, projection back."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden, config.n_embd)
        # Marker attribute, presumably consumed by a custom weight init — confirm.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        hidden = self.c_fc(x)
        hidden = self.gelu(hidden)
        return self.c_proj(hidden)
|
| 724 |
+
class AttnBranch(nn.Module):
    """Pre-norm wrapper: applies `norm`, then `attn`, to its input."""

    def __init__(self, norm, attn):
        super().__init__()
        self.norm = norm
        self.attn = attn

    def forward(self, x):
        normed = self.norm(x)
        return self.attn(normed)
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
class Block(nn.Module):
    """One transformer block whose residuals are managed by hyper-connections.

    Pre-norm layout: ln_1 feeds the attention branch, ln_2 the MLP branch.
    Instead of plain `x = x + f(x)`, each branch is wrapped by the `init_hc`
    factory; layer indices 2*layer_idx and 2*layer_idx+1 give every wrapper
    in the network a unique index.
    """

    def __init__(self, config, layer_idx, init_hc):
        super().__init__()

        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)

        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

        self.attn_branch = AttnBranch(self.ln_1, self.attn)

        # Hyper-connection settings shared by both sub-block wrappers.
        hc_keys = (
            "mhc",
            "sinkhorn_iters",
            "sinkhorn_tau",
            "mhc_h_res_proj",
            "ns_steps",
            "ns_eps",
            "ns_coeffs",
        )
        hc_kwargs = {key: getattr(config, key) for key in hc_keys}

        self.hc_attn = init_hc(
            dim=config.n_embd,
            branch=self.attn_branch,
            layer_index=layer_idx * 2,
            **hc_kwargs,
        )

        self.hc_mlp = init_hc(
            dim=config.n_embd,
            branch=nn.Sequential(self.ln_2, self.mlp),
            layer_index=layer_idx * 2 + 1,
            **hc_kwargs,
        )

    def forward(self, x):
        return self.hc_mlp(self.hc_attn(x))
|
| 774 |
+
class GPTConfig(PretrainedConfig):
    """Hugging Face configuration for the custom GPT with hyper-connections.

    Core transformer fields follow nanoGPT naming (block_size, vocab_size,
    n_layer, n_head, n_embd, dropout, bias); the hc_* / mhc / sinkhorn_* /
    ns_* fields configure the hyper-connection residual wrappers used
    inside each Block.
    """

    model_type = "custom_gpt"

    def __init__(
        self,
        block_size=1024,
        vocab_size=50304,
        n_layer=12,
        n_head=12,
        n_embd=768,
        dropout=0.0,
        bias=True,
        hc_num_streams=1,
        hc_num_fracs=1,
        hc_disable=False,
        mhc=False,
        sinkhorn_iters=10,
        sinkhorn_tau=0.05,
        mhc_h_res_proj="sinkhorn",
        ns_steps=5,
        ns_eps=1e-7,
        ns_coeffs=(3.0, -3.2, 1.2),
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Core transformer hyper-parameters.
        self.block_size = block_size  # max sequence length (positional table size)
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        # NOTE(review): dropout and bias are stored here but the attention and
        # MLP modules in this file do not read them — confirm this is intended.
        self.dropout = dropout
        self.bias = bias

        # Hyper-connection settings.
        self.hc_num_streams = hc_num_streams  # residual streams per token
        self.hc_num_fracs = hc_num_fracs
        self.hc_disable = hc_disable  # disables hyper-connections (presumably plain residuals — confirm)
        self.mhc = mhc
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.mhc_h_res_proj = mhc_h_res_proj
        # ns_* presumably parameterize a Newton–Schulz style iteration — confirm.
        self.ns_steps = ns_steps
        self.ns_eps = ns_eps
        self.ns_coeffs = ns_coeffs

        # HF compatibility aliases so generic utilities can read standard names.
        self.num_hidden_layers = n_layer
        self.num_attention_heads = n_head
        self.hidden_size = n_embd
        self.max_position_embeddings = block_size
|
| 824 |
+
|
| 825 |
+
class GPT(PreTrainedModel, GenerationMixin):
    """GPT-style causal language model with hyper-connection residual streams.

    Pipeline: token + position embeddings -> expand to residual streams ->
    n_layer Blocks -> final LayerNorm -> reduce streams -> tied LM head.
    """

    config_class = GPTConfig

    def __init__(self, config):
        super().__init__(config)

        # Factory for the per-block hyper-connection wrappers, plus the
        # functions that expand/reduce the residual streams around the stack.
        init_hc, expand_stream, reduce_stream = (
            get_init_and_expand_reduce_stream_functions(
                config.hc_num_streams,
                num_fracs=config.hc_num_fracs,
                disable=config.hc_disable,
            )
        )

        self.expand_stream = expand_stream
        self.reduce_stream = reduce_stream

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
                wpe=nn.Embedding(config.block_size, config.n_embd),  # learned positions
                h=nn.ModuleList(
                    [Block(config, i, init_hc) for i in range(config.n_layer)]
                ),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )

        # Weight tying: the token embedding and LM head share one matrix.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

        self.post_init()

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        **kwargs,
    ):
        # We do NOT use KV cache yet, so always feed full sequence
        return {
            "input_ids": input_ids,
            "past_key_values": None,
        }

    def forward(
        self,
        input_ids=None,
        # NOTE(review): attention_mask is accepted but never used below, so
        # padded batches attend to pad tokens — confirm this is intended.
        attention_mask=None,
        labels=None,
        past_key_values=None,  # unused: no KV cache implemented
        use_cache=None,  # unused
        **kwargs,
    ):

        b, t = input_ids.size()
        assert t <= self.config.block_size  # sequence must fit the positional table

        pos = torch.arange(0, t, device=input_ids.device).unsqueeze(0)

        x = self.transformer.wte(input_ids) + self.transformer.wpe(pos)
        x = self.expand_stream(x)  # replicate residuals across hyper-connection streams

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        x = self.reduce_stream(x)  # collapse the streams back to one residual

        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # NOTE(review): no one-position shift is applied here, unlike the
            # standard HF causal-LM convention (logits[..., :-1] vs
            # labels[..., 1:]); callers must pass pre-shifted labels —
            # confirm against the training code.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:df198621f8d5ee6f68a316ddd5730b5816c022ad1ed6daaa309ec09bc0b79e7c
|
| 3 |
+
size 1520347603
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<|endoftext|>",
|
| 3 |
+
"eos_token": "<|endoftext|>",
|
| 4 |
+
"unk_token": "<|endoftext|>"
|
| 5 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"50256": {
|
| 5 |
+
"content": "<|endoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
"bos_token": "<|endoftext|>",
|
| 14 |
+
"clean_up_tokenization_spaces": false,
|
| 15 |
+
"eos_token": "<|endoftext|>",
|
| 16 |
+
"extra_special_tokens": {},
|
| 17 |
+
"model_max_length": 1024,
|
| 18 |
+
"tokenizer_class": "GPT2Tokenizer",
|
| 19 |
+
"unk_token": "<|endoftext|>"
|
| 20 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|