booydar committed
Commit bae3bed · verified · 1 Parent(s): 734465d

Upload RMT

Files changed (3)
  1. config.json +5 -1
  2. language_modeling.py +105 -0
  3. model.safetensors +2 -2
config.json CHANGED
@@ -3,6 +3,10 @@
   "architectures": [
     "RMT"
   ],
+  "auto_map": {
+    "AutoConfig": "language_modeling.RMTConfig",
+    "AutoModel": "language_modeling.RMT"
+  },
   "base_model_name": "HuggingFaceTB/SmolLM2-135M",
   "bos_token_id": 0,
   "eos_token_id": 0,
@@ -13,5 +17,5 @@
   "recurrent_wrapper_cls": "modeling_rmt.experimental:RecurrentWrapperNoSegmentationGenerate",
   "think_token_id": 8,
   "torch_dtype": "float32",
-  "transformers_version": "4.53.1"
+  "transformers_version": "4.54.1"
 }
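
The new auto_map entries register the custom classes with the transformers auto-class machinery, so the checkpoint can be loaded with trust_remote_code=True. A minimal loading sketch, with a placeholder "<user>/<repo>" standing in for this repository's Hub id (not part of this commit):

    # Hypothetical usage; "<user>/<repo>" is a placeholder Hub id.
    # trust_remote_code=True lets transformers import RMTConfig and RMT
    # from language_modeling.py inside the repo, as declared in auto_map.
    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("<user>/<repo>", trust_remote_code=True)
    model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)

Note that the custom from_pretrained in language_modeling.py resolves model.safetensors through a filesystem path, so loading from a remote id may require a local snapshot first (see the sketch after that file).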
language_modeling.py ADDED
@@ -0,0 +1,105 @@
+import importlib
+from transformers import PreTrainedModel, PretrainedConfig
+# from lm_experiments_tools.utils import get_cls_by_name
+
+
+def get_cls_by_name(name: str) -> type:
+    """Get a class by its module path and name.
+
+    Args:
+        name (str): e.g., transformers:T5ForConditionalGeneration, modeling_t5:my_class
+
+    Returns:
+        type: the class found for `name`
+    """
+    module_name, cls_name = name.split(':')
+    return getattr(importlib.import_module(module_name), cls_name)
+
+
+class RMTConfig(PretrainedConfig):
+    model_type = "rmt"
+
+    def __init__(self,
+                 base_model_name="HuggingFaceTB/SmolLM2-135M",
+                 num_mem_tokens=16,
+                 max_n_segments=10,
+                 think_token_id=None,
+                 answer_token_id=None,
+                 bos_token_id=None,
+                 eos_token_id=None,
+                 memory_cell_cls='modeling_rmt.language_modeling:MemoryCell',
+                 recurrent_wrapper_cls='modeling_rmt.experimental:RecurrentWrapperNoSegmentationGenerate',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.base_model_name = base_model_name
+        self.num_mem_tokens = num_mem_tokens
+        self.max_n_segments = max_n_segments
+        self.think_token_id = think_token_id
+        self.answer_token_id = answer_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.memory_cell_cls = memory_cell_cls
+        self.recurrent_wrapper_cls = recurrent_wrapper_cls
+
+    def get(self, attr: str, default=None):
+        # Dict-style accessor so the config can be queried like a mapping.
+        return getattr(self, attr, default)
+
+
+class RMT(PreTrainedModel):
+    config_class = RMTConfig
+
+    def __init__(self, config: RMTConfig):
+        super().__init__(config)
+        from transformers import AutoConfig, AutoModelForCausalLM
+        # Build the backbone from its config only; trained weights are
+        # loaded later from the checkpoint.
+        base_config = AutoConfig.from_pretrained(config.base_model_name)
+        base_model = AutoModelForCausalLM.from_config(base_config)
+
+        memory_cell_cls = get_cls_by_name(config.memory_cell_cls)
+        recurrent_wrapper_cls = get_cls_by_name(config.recurrent_wrapper_cls)
+
+        self.rmt_config = config
+        memory_cell = memory_cell_cls(base_model, num_mem_tokens=config.num_mem_tokens)
+        self.rmt = recurrent_wrapper_cls(
+            memory_cell,
+            max_n_segments=config.max_n_segments,
+            think_token_id=config.think_token_id,
+            answer_token_id=config.answer_token_id,
+            bos_token_id=config.bos_token_id,
+            eos_token_id=config.eos_token_id
+        )
+
+    def forward(self, *args, **kwargs):
+        return self.rmt(*args, **kwargs)
+
+    def generate(self, *args, **kwargs):
+        return self.rmt.generate(*args, **kwargs)
+
+    def load_state_dict(self, state_dict, strict=True, assign=False):
+        try:
+            return super().load_state_dict(state_dict, strict=strict, assign=assign)
+        except RuntimeError:
+            # Checkpoints may be saved without the wrapping module prefix;
+            # fall back to loading directly into the recurrent wrapper.
+            print("Failed to load state, retrying with RMT loader.")
+            result = self.rmt.load_state_dict(state_dict, strict=True, assign=assign)
+            print("Success!")
+            return result
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs):
+        if config is None:
+            config = RMTConfig.from_pretrained(pretrained_model_name_or_path)
+        model = cls(config)
+
+        import os
+        from collections import OrderedDict
+        from safetensors import safe_open
+
+        # Weights are read from a local directory, so a Hub id must be
+        # downloaded before calling this.
+        safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+        state_dict = OrderedDict()
+        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                state_dict[key] = f.get_tensor(key)
+
+        # strict=False so the custom load_state_dict can handle key mismatches.
+        model.load_state_dict(state_dict, strict=False)
+
+        return model
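
Because the overridden from_pretrained joins pretrained_model_name_or_path with model.safetensors on the local filesystem, one way to load a Hub-hosted copy is to snapshot it first. A sketch, assuming a placeholder repo id:

    # Hypothetical flow: fetch the repo locally, then load through the custom
    # from_pretrained, which reads model.safetensors from the snapshot directory.
    from huggingface_hub import snapshot_download
    from language_modeling import RMT

    local_dir = snapshot_download("<user>/<repo>")  # placeholder Hub id
    model = RMT.from_pretrained(local_dir)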
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1b827dad8b86cc6d24bc2141e73a8f85040d729f1f9a5c1f04aff459397a09ee
-size 538170208
+oid sha256:a9aaa76bc4eb456e998ecc096bdd5c05f9b83662ec42362fba6eb4580a839ea9
+size 269140352