NeTS-lab committed
Commit 9e31d55 · verified · 1 Parent(s): a86f551

Upload EMG model with MorPiece tokenizer

README.md ADDED
@@ -0,0 +1,60 @@
+ ---
+ language: en
+ library_name: transformers
+ tags:
+ - emg
+ - morphology
+ - language-model
+ - causal-lm
+ - morpiece-tokenizer
+ license: apache-2.0
+ pipeline_tag: text-generation
+ ---
+
+ # EMG Language Model
+
+ This is an EMG (Enhanced Morphological Generation) language model paired with the MorPiece tokenizer.
+
+ ## Model Details
+
+ - **Model Type**: Causal Language Model
+ - **Architecture**: EMG with morphological awareness
+ - **Tokenizer**: MorPiece (morphology-aware tokenization)
+ - **Parameters**: 79.75M
+ - **Vocabulary Size**: 60001
+
+ ## Usage
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load the model and tokenizer (trust_remote_code is required for the custom classes)
+ tokenizer = AutoTokenizer.from_pretrained("your-username/your-model-name", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("your-username/your-model-name", trust_remote_code=True)
+
+ # Generate text
+ input_text = "The future of AI is"
+ inputs = tokenizer(input_text, return_tensors="pt")
+ outputs = model.generate(**inputs, max_length=50)
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ print(generated_text)
+ ```
+
+ ## Model Architecture
+
+ The EMG model incorporates morphological structure to improve language understanding and generation.
+ The MorPiece tokenizer segments words along morpheme-like boundaries, which improves the handling of inflected and derived word forms.
+
+ ## Training
+
+ This model was trained on conversational data with morphological enhancement.
+
+ ## Limitations
+
+ - This model is designed for research purposes
+ - May not perform optimally on all downstream tasks without fine-tuning
+ - Requires `trust_remote_code=True` due to the custom architecture
+
+ ## Citation
+
+ If you use this model, please cite the original EMG paper and implementation.
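
Because all of the custom code ships in this commit, a local checkout can also be loaded without going through the Hub, using the `load_emg_model_and_tokenizer` helper defined in `modeling_emg.py` below. This is a minimal sketch; the `./emg-model` path is a hypothetical clone location, not part of the repo.

```python
# Local-loading sketch, assuming this repo is cloned to ./emg-model (hypothetical
# path) and its .py files are importable from the working directory.
from modeling_emg import load_emg_model_and_tokenizer

model, tokenizer = load_emg_model_and_tokenizer("./emg-model")
ids = tokenizer.encode("The future of AI is")
print(ids)                    # MorPiece token ids
print(tokenizer.decode(ids))  # round-trip back to text
```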
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "EMGForCausalLM"
+   ],
+   "dropout": 0.01,
+   "embedding_dim": 650,
+   "hidden_dim": 650,
+   "model_type": "emg",
+   "num_layers": 1,
+   "pad_token_id": 60004,
+   "torch_dtype": "float32",
+   "transformers_version": "4.52.3",
+   "use_gradient_checkpointing": false,
+   "use_layer_norm": true,
+   "vocab_size": 60001,
+   "auto_map": {
+     "AutoConfig": "modeling_emg.EMGConfig",
+     "AutoModel": "modeling_emg.EMGLanguageModel",
+     "AutoModelForCausalLM": "modeling_emg.EMGForCausalLM",
+     "AutoTokenizer": "modeling_emg.MorPieceTokenizer"
+   }
+ }
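
The `auto_map` block is what lets the stock `Auto*` loaders resolve `model_type: "emg"` to the custom classes in `modeling_emg.py` when `trust_remote_code=True` is passed; without it, transformers has no registered architecture for this model type. A small illustration (the repo id is a placeholder, as in the README):

```python
from transformers import AutoConfig

# trust_remote_code=True tells transformers to follow auto_map and execute
# modeling_emg.py from the repo; "your-username/your-model-name" is a placeholder.
config = AutoConfig.from_pretrained("your-username/your-model-name", trust_remote_code=True)
print(config.model_type)  # "emg"
print(config.hidden_dim)  # 650
```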
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.52.3"
+ }
model_eMG_simplified.py ADDED
@@ -0,0 +1,230 @@
+ import os
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F  # used by EMGLanguageModel.generate below
+ import torch.utils.checkpoint
+ from transformers import PreTrainedModel, PretrainedConfig
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # ===================== OPTIMIZED EMG MODEL =====================
+
+
+ class OptimizedEMGCell(nn.Module):
+     def __init__(self, input_size, hidden_size, dropout_rate=0.1, use_layer_norm=False):
+         super(OptimizedEMGCell, self).__init__()
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.use_layer_norm = use_layer_norm
+         self.clamp_min = -1
+         self.clamp_max = 1
+
+         # Fused linear transformations for better efficiency
+         self.input_transform_linear = nn.Linear(input_size, hidden_size * 2)
+         self.hidden_transform_linear = nn.Linear(hidden_size, hidden_size * 2)
+
+         # SIMPLIFIED: use standard dropout instead of variational dropout
+         self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
+
+         # Layer normalization for training stability
+         if use_layer_norm:
+             self.input_norm = nn.LayerNorm(hidden_size)
+             self.hidden_norm = nn.LayerNorm(hidden_size)
+             self.cell_norm = nn.LayerNorm(hidden_size)
+
+         self.init_weights()
+
+     def init_weights(self):
+         for linear in [self.input_transform_linear, self.hidden_transform_linear]:
+             # Use smaller initialization for RNN stability
+             nn.init.uniform_(linear.weight, -0.1, 0.1)
+             nn.init.zeros_(linear.bias)
+
+     def forward(self, input, hidden):
+         h_prev, c_prev = hidden
+
+         # Project input and hidden states
+         input_connections = self.input_transform_linear(input)
+         hidden_connections = self.hidden_transform_linear(h_prev)
+
+         # Split projections
+         i_move, i_merge = torch.chunk(input_connections, 2, dim=-1)
+         h_move, h_merge = torch.chunk(hidden_connections, 2, dim=-1)
+
+         # EMG computation
+         # merge_gate = torch.clamp(i_merge, self.clamp_min, self.clamp_max) * torch.sigmoid(torch.clamp(h_merge, self.clamp_min, self.clamp_max))
+         merge_gate = torch.clamp(i_merge * torch.sigmoid(h_merge), self.clamp_min, self.clamp_max)
+         move_gate = torch.clamp(torch.sigmoid(i_move) * h_move, self.clamp_min, self.clamp_max)
+
+         if self.use_layer_norm:
+             c_prev = self.cell_norm(c_prev)
+
+         context_gate = torch.tanh(torch.clamp(c_prev + merge_gate, self.clamp_min, self.clamp_max))
+
+         if self.use_layer_norm:
+             context_gate = self.input_norm(context_gate)
+
+         c_next = context_gate
+
+         if self.use_layer_norm:
+             c_next = self.hidden_norm(c_next)
+
+         # Apply dropout to the output instead of complex variational dropout
+         m_next = (1 - move_gate) * merge_gate + move_gate * c_next
+         if self.dropout is not None:
+             m_next = self.dropout(m_next)
+
+         return m_next, c_next
+
+
+ class OptimizedEMG(nn.Module):
+     """Enhanced EMG with gradient checkpointing and other optimizations"""
+
+     def __init__(self, input_size, hidden_size, num_layers, dropout_rate=0.1,
+                  use_gradient_checkpointing=False):
+         super(OptimizedEMG, self).__init__()
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.use_gradient_checkpointing = use_gradient_checkpointing
+
+         # NOTE: use_layer_norm is not forwarded here, so the cells always run
+         # with their default (use_layer_norm=False) regardless of the config.
+         self.cells = nn.ModuleList([
+             OptimizedEMGCell(
+                 input_size if i == 0 else hidden_size,
+                 hidden_size,
+                 dropout_rate
+             ) for i in range(num_layers)
+         ])
+
+     def forward(self, x, hidden=None):
+         batch_size, seq_len, _ = x.size()
+
+         if hidden is None:
+             hidden = [(torch.zeros(batch_size, self.hidden_size, device=x.device),
+                        torch.zeros(batch_size, self.hidden_size, device=x.device))
+                       for _ in range(self.num_layers)]
+
+         outputs = []
+
+         for t in range(seq_len):
+             layer_input = x[:, t, :]
+
+             for layer_idx, cell in enumerate(self.cells):
+                 m_prev, c_prev = hidden[layer_idx]
+
+                 if self.use_gradient_checkpointing and self.training:
+                     m_next, c_next = torch.utils.checkpoint.checkpoint(
+                         cell, layer_input, (m_prev, c_prev), use_reentrant=False
+                     )
+                 else:
+                     m_next, c_next = cell(layer_input, (m_prev, c_prev))
+
+                 hidden[layer_idx] = (m_next, c_next)
+                 layer_input = m_next
+
+             outputs.append(layer_input)
+
+         output = torch.stack(outputs, dim=1)
+         return output, hidden
+
+
+ # ===================== HUGGING FACE COMPATIBLE MODEL =====================
+
+ class EMGConfig(PretrainedConfig):
+     """Configuration class for the EMG model"""
+     model_type = "emg"
+
+     def __init__(
+         self,
+         vocab_size=50000,
+         embedding_dim=512,
+         hidden_dim=512,
+         num_layers=2,
+         dropout=0.1,
+         use_layer_norm=True,
+         use_gradient_checkpointing=False,
+         tie_word_embeddings=True,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.embedding_dim = embedding_dim
+         self.hidden_dim = hidden_dim
+         self.num_layers = num_layers
+         self.dropout = dropout
+         self.use_layer_norm = use_layer_norm
+         self.use_gradient_checkpointing = use_gradient_checkpointing
+         self.tie_word_embeddings = tie_word_embeddings
+
+
+ class EMGLanguageModel(PreTrainedModel):
+     """Hugging Face compatible EMG Language Model"""
+     config_class = EMGConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+         self.emg = OptimizedEMG(
+             config.embedding_dim,
+             config.hidden_dim,
+             config.num_layers,
+             config.dropout,
+             config.use_gradient_checkpointing
+         )
+         self.output_projection = nn.Linear(config.hidden_dim, config.vocab_size)
+
+         # Tie embedding and output weights if the dimensions match
+         if config.tie_word_embeddings and config.embedding_dim == config.hidden_dim:
+             self.output_projection.weight = self.embedding.weight
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids, hidden=None, labels=None, **kwargs):
+         embedded = self.embedding(input_ids)
+         output, hidden = self.emg(embedded, hidden)
+         logits = self.output_projection(output)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                             shift_labels.view(-1))
+
+         return {'loss': loss, 'logits': logits, 'hidden_states': hidden}
+
+     def generate(self, input_ids, max_length=50, temperature=1.0, top_k=50, **kwargs):
+         # Extra kwargs (e.g., attention_mask from a tokenizer call) are accepted
+         # for HuggingFace-style invocations and ignored by this simple sampler.
+         self.eval()
+         generated = input_ids
+         hidden = None
+
+         with torch.no_grad():
+             for _ in range(max_length):
+                 # Feed the full prompt on the first step so the recurrent state
+                 # sees the whole context; afterwards only the newest token is needed.
+                 step_input = generated if hidden is None else generated[:, -1:]
+                 outputs = self.forward(step_input, hidden)
+                 hidden = outputs['hidden_states']  # carry the recurrent state forward
+                 logits = outputs['logits'][:, -1, :] / temperature
+
+                 # Top-k sampling
+                 top_k_logits, top_k_indices = torch.topk(logits, top_k)
+                 probs = F.softmax(top_k_logits, dim=-1)
+                 next_token = top_k_indices.gather(1, torch.multinomial(probs, num_samples=1))
+
+                 generated = torch.cat([generated, next_token], dim=1)
+
+         return generated
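
A quick smoke test of the classes above. This is a sketch with deliberately tiny, illustrative sizes (the released checkpoint uses `vocab_size=60001`, `embedding_dim=hidden_dim=650`, `num_layers=1`); it only checks that the recurrent forward pass and the shifted cross-entropy loss run.

```python
import torch
from model_eMG_simplified import EMGConfig, EMGLanguageModel

# Tiny illustrative configuration, not the released checkpoint's sizes.
config = EMGConfig(vocab_size=100, embedding_dim=32, hidden_dim=32, num_layers=2)
model = EMGLanguageModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))  # (batch, seq_len)
out = model(input_ids, labels=input_ids)
print(out['logits'].shape)  # torch.Size([2, 8, 100])
print(out['loss'])          # scalar next-token cross-entropy
```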
modeling_emg.py ADDED
@@ -0,0 +1,319 @@
+ """
+ HuggingFace Integration for EMG Model and MorPiece Tokenizer
+ This file makes the custom model and tokenizer compatible with HuggingFace and lm_eval
+ """
+
+ import json
+ import os
+ from typing import List, Optional, Union, Dict, Any
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     PreTrainedModel,
+     PretrainedConfig,
+     PreTrainedTokenizer,
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     GenerationMixin,
+ )
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ # Import the existing classes
+ from model_eMG_simplified import EMGLanguageModel, EMGConfig, OptimizedEMG, OptimizedEMGCell
+ from tokenizer_MorPiece import MorPiece
+
+
+ class MorPieceTokenizer(PreTrainedTokenizer):
+     """
+     HuggingFace compatible wrapper for the MorPiece tokenizer
+     """
+
+     def __init__(self,
+                  vocab_file=None,
+                  model_file=None,
+                  unk_token="<unk>",
+                  pad_token="<pad>",
+                  bos_token="<s>",
+                  eos_token="</s>",
+                  **kwargs):
+
+         # Initialize the MorPiece tokenizer
+         self.morpiece = MorPiece()
+
+         # Load from file if provided
+         if vocab_file or model_file:
+             model_path = vocab_file or model_file
+             if os.path.isdir(model_path):
+                 self.morpiece.from_pretrained(model_path)
+             else:
+                 # Load from a JSON file
+                 with open(model_path, 'r') as f:
+                     data = json.load(f)
+                 self.morpiece.roots = data.get('roots', data)
+                 if 'vocab' in data:
+                     self.morpiece.vocab_to_id = data['vocab']
+                 else:
+                     self.morpiece.build_vocab_lookup()
+
+         # Get vocabulary (must exist before the parent __init__ runs)
+         self.vocab = self.morpiece.get_vocab()
+
+         # Set special tokens
+         super().__init__(
+             unk_token=unk_token,
+             pad_token=pad_token,
+             bos_token=bos_token,
+             eos_token=eos_token,
+             **kwargs
+         )
+
+     @property
+     def vocab_size(self):
+         return len(self.vocab)
+
+     def get_vocab(self):
+         return self.vocab.copy()
+
+     def _tokenize(self, text: str) -> List[str]:
+         """Tokenize text into tokens"""
+         # For HuggingFace compatibility, we need to return string tokens;
+         # MorPiece.decode returns a list of surface tokens, not a joined string.
+         token_ids = self.morpiece.encode(text)
+         tokens = self.morpiece.decode(token_ids)
+         return tokens
+
+     def _convert_token_to_id(self, token: str) -> int:
+         """Convert a token to its ID"""
+         return self.vocab.get(token, self.vocab.get(self.unk_token, 0))
+
+     def _convert_id_to_token(self, index: int) -> str:
+         """Convert an ID to its token"""
+         for token, idx in self.vocab.items():
+             if idx == index:
+                 return token
+         return self.unk_token
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         """Convert tokens back to a string"""
+         # Handle special tokens
+         text = "".join(tokens)
+         # Clean up special tokens for display
+         for special_token in [self.pad_token, self.bos_token, self.eos_token]:
+             if special_token:
+                 text = text.replace(special_token, "")
+         return text.strip()
+
+     def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
+         """Encode text to token IDs"""
+         if add_special_tokens and self.bos_token:
+             text = f"{self.bos_token} {text}"
+         if add_special_tokens and self.eos_token:
+             text = f"{text} {self.eos_token}"
+
+         return self.morpiece.encode(text)
+
+     def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **kwargs) -> str:
+         """Decode token IDs to text"""
+         tokens = []
+         for token_id in token_ids:
+             token = self._convert_id_to_token(token_id)
+             if skip_special_tokens and token in [self.pad_token, self.bos_token, self.eos_token, self.unk_token]:
+                 continue
+             tokens.append(token)
+         return self.convert_tokens_to_string(tokens)
+
+     def save_pretrained(self, save_directory: str, **kwargs):
+         """Save the tokenizer"""
+         os.makedirs(save_directory, exist_ok=True)
+
+         # Save MorPiece data
+         tokenizer_file = os.path.join(save_directory, "tokenizer.json")
+         self.morpiece.save(tokenizer_file)
+
+         # Save tokenizer config
+         config = {
+             "tokenizer_class": "MorPieceTokenizer",
+             "unk_token": self.unk_token,
+             "pad_token": self.pad_token,
+             "bos_token": self.bos_token,
+             "eos_token": self.eos_token,
+         }
+
+         config_file = os.path.join(save_directory, "tokenizer_config.json")
+         with open(config_file, 'w') as f:
+             json.dump(config, f, indent=2)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
+         """Load the tokenizer from a pretrained directory or file"""
+         return cls(vocab_file=pretrained_model_name_or_path, **kwargs)
+
+
+ class EMGForCausalLM(EMGLanguageModel, GenerationMixin):
+     """
+     Enhanced EMG model with better HuggingFace compatibility for lm_eval.
+     Inherits from GenerationMixin so transformers recognizes it as generation-capable.
+     Note: EMGLanguageModel precedes GenerationMixin in the MRO, so its custom
+     generate() is the one actually invoked.
+     """
+
+     def __init__(self, config):
+         # Initialize EMGLanguageModel first
+         EMGLanguageModel.__init__(self, config)
+         # Then initialize GenerationMixin
+         GenerationMixin.__init__(self)
+         self.config = config
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         past_key_values: Optional[tuple] = None,
+         use_cache: Optional[bool] = None,
+         **kwargs
+     ) -> CausalLMOutputWithPast:
+         """
+         Forward pass with a HuggingFace compatible output format
+         """
+         # Get embeddings
+         embedded = self.embedding(input_ids)
+
+         # Pass through EMG layers (the recurrent state plays the role of the cache)
+         output, hidden = self.emg(embedded, past_key_values)
+
+         # Get logits
+         logits = self.output_projection(output)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                             shift_labels.view(-1))
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=hidden if use_cache else None,
+             hidden_states=output,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.Tensor,
+         past_key_values=None,
+         attention_mask=None,
+         **kwargs
+     ):
+         """Prepare inputs for generation"""
+         return {
+             "input_ids": input_ids,
+             "past_key_values": past_key_values,
+             "attention_mask": attention_mask,
+         }
+
+     def _reorder_cache(self, past_key_values, beam_idx):
+         """Reorder the cache for beam search"""
+         if past_key_values is None:
+             return None
+
+         reordered_cache = []
+         for layer_cache in past_key_values:
+             if isinstance(layer_cache, tuple):
+                 reordered_cache.append(tuple(
+                     cache.index_select(0, beam_idx) for cache in layer_cache
+                 ))
+             else:
+                 reordered_cache.append(layer_cache.index_select(0, beam_idx))
+         return tuple(reordered_cache)
+
+
+ # Register the custom classes with transformers
+ def register_emg_model():
+     """Register the EMG model and MorPiece tokenizer with transformers"""
+
+     # Register config
+     AutoConfig.register("emg", EMGConfig)
+
+     # Register model
+     AutoModel.register(EMGConfig, EMGLanguageModel)
+     AutoModelForCausalLM.register(EMGConfig, EMGForCausalLM)
+
+     # Register tokenizer
+     AutoTokenizer.register(EMGConfig, MorPieceTokenizer)
+
+     print("EMG model and MorPiece tokenizer registered with transformers!")
+
+
+ def load_emg_model_and_tokenizer(model_path: str):
+     """
+     Load the EMG model and MorPiece tokenizer from a saved directory
+
+     Args:
+         model_path: Path to the saved model directory
+
+     Returns:
+         tuple: (model, tokenizer)
+     """
+     # Register classes first
+     register_emg_model()
+
+     # Load model
+     config = EMGConfig.from_pretrained(model_path)
+     model = EMGForCausalLM.from_pretrained(model_path, config=config)
+
+     # Load tokenizer
+     tokenizer = MorPieceTokenizer.from_pretrained(model_path)
+
+     # Set the pad token id in the model config if not already set
+     if not hasattr(config, 'pad_token_id') or config.pad_token_id is None:
+         config.pad_token_id = tokenizer.pad_token_id
+         model.config.pad_token_id = tokenizer.pad_token_id
+
+     return model, tokenizer
+
+
+ def test_model_and_tokenizer(model_path: str):
+     """Test the loaded model and tokenizer"""
+     model, tokenizer = load_emg_model_and_tokenizer(model_path)
+
+     # Test encoding/decoding
+     test_text = "Hello world, this is a test."
+     print(f"Original text: {test_text}")
+
+     # Encode
+     encoded = tokenizer.encode(test_text)
+     print(f"Encoded: {encoded}")
+
+     # Decode
+     decoded = tokenizer.decode(encoded, skip_special_tokens=True)
+     print(f"Decoded: {decoded}")
+
+     # Test a model forward pass
+     input_ids = torch.tensor([encoded])
+     with torch.no_grad():
+         outputs = model(input_ids)
+     print(f"Model output shape: {outputs.logits.shape}")
+     print(f"Model output type: {type(outputs)}")
+
+     print("Model and tokenizer are working correctly!")
+     return model, tokenizer
+
+
+ if __name__ == "__main__":
+     # Example usage
+     model_path = "path/to/your/saved/model"  # Replace with your model path
+
+     # Register the classes
+     register_emg_model()
+
+     # Test loading
+     try:
+         model, tokenizer = test_model_and_tokenizer(model_path)
+         print("✅ Model and tokenizer loaded successfully!")
+     except Exception as e:
+         print(f"❌ Error loading model: {e}")
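
The wrapper returns standard `CausalLMOutputWithPast` objects, so downstream tools such as lm_eval can consume logits and loss directly. A minimal forward-pass sketch with an illustrative tiny config (the released checkpoint uses vocab 60001 and dimension 650):

```python
import torch
from modeling_emg import EMGConfig, EMGForCausalLM

# Illustrative sizes only, not the released checkpoint's.
config = EMGConfig(vocab_size=100, embedding_dim=32, hidden_dim=32, num_layers=1)
model = EMGForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 6))
with torch.no_grad():
    out = model(input_ids, labels=input_ids, use_cache=True)
print(type(out).__name__)               # CausalLMOutputWithPast
print(out.logits.shape)                 # torch.Size([1, 6, 100])
print(out.past_key_values is not None)  # recurrent state returned as the "cache"
```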
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a5f54b4392c29c455472735d1de207e490f1ef9789ac39df15a50a5117feba81
+ size 163016093
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch>=1.9.0
+ transformers>=4.20.0
+ numpy
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_MorPiece.py ADDED
@@ -0,0 +1,350 @@
+ import os
+ import json
+ from math import log
+
+
+ class MorPiece:
+     def __init__(self, vocab_size=30000, min_frequency=2, cutoff=8, bf=10, special_tokens=None):
+         self.tokenization_to_print = "TP left-right \t BF left-right \t TP right-left \t BF right-left\n"  # for debugging only
+         if special_tokens is None:
+             special_tokens = ['<unk>', '<pad>', '<s>', '</s>']
+         self.special_tokens = special_tokens
+         self.reserved_keys = {'[RSX]', '##', 'IDX', '++'}
+         self.vocab_size = vocab_size
+         self.min_frequency = min_frequency
+         self.bf = bf
+         self.roots = {'[RSX]': {}, '++': {}}
+         self.roots_unoptimized = {}
+         self.infls = {}
+         self.types = {}
+         self.last_item_in_trie = {}
+         self.idx = 0
+         self.tokens = []
+         self.suffixes = []
+         self.tokens_bf = []
+         self.suffixes_bf = []
+         self.prefix = ""
+         self.n_prefix = 0
+         self.n_suffix = 0
+         self.tokenized_words = []
+         self.tokenized_word_longest = ""
+         self.tokenized_word_idx_longest = ""
+         self.cutoff = cutoff  # ln(8) > 2, so non-branching paths will be ignored
+         self.num_tokens_in_corpus = 0
+         self.num_chars_in_corpus = 0
+         self.num_chars_in_trie = 0
+         self.num_chars_in_optimized_trie = 0
+         self.set_special_tokens(self.special_tokens)
+
+     def train(self, corpus: str):  # create the vocabulary
+         words = corpus.split()
+         print("MorPiece tokenizer training: processing words...")
+         for word in words:
+             word_alpha = ''.join([char for char in word if char.isalpha() or char == "'"])
+             if not word_alpha:
+                 word = ''.join([char for char in word])
+             else:
+                 word = word_alpha
+             if word:
+                 self.build_trie(word, self.roots_unoptimized)  # create the roots trie
+                 self.build_trie(word[::-1], self.infls)  # create the inflections trie
+                 if word not in self.types:  # count tokens and chars in the corpus
+                     self.types[word] = 1
+                 else:
+                     self.types[word] += 1
+                 self.num_tokens_in_corpus += 1
+                 self.num_chars_in_corpus += len(word)
+         self.types = dict(sorted(self.types.items(), key=lambda item: item[1], reverse=True))
+         sort_trie_by_freq(self.roots_unoptimized)
+         sort_trie_by_freq(self.infls)
+
+         print("MorPiece tokenizer training: trie optimization...")
+         self.optimize(self.types)
+
+         print(f"Built final vocabulary with {self.get_vocab_size()} tokens")
+         print(f"Most common tokens: {list(self.types.items())[:20]}")
+
+     def build_trie(self, wordpiece, root):  # build the trie and register the number of traversals in '##'
+         if wordpiece[0] in root:
+             root[wordpiece[0]]['##'] += 1
+             self.num_chars_in_trie += 1
+             if len(wordpiece) > 1:
+                 self.build_trie(wordpiece[1:], root[wordpiece[0]])
+             else:
+                 if 'END' not in root[wordpiece[0]]:
+                     root[wordpiece[0]]['END'] = None
+         else:
+             root[wordpiece[0]] = {}
+             root[wordpiece[0]]['##'] = 1
+             if len(wordpiece) > 1:
+                 self.build_trie(wordpiece[1:], root[wordpiece[0]])
+
+     def set_special_tokens(self, tokens):
+         for item in tokens:
+             if item not in self.roots['[RSX]'].keys():
+                 self.roots['[RSX]'][item] = {'IDX': None}
+                 self.roots['[RSX]'][item]['IDX'] = self.idx
+                 self.idx += 1
+
+     # assign idx based on word freq and add potential inflection links in the root trie; remove frequency at the end
+     def optimize(self, words):
+         for word, freq in words.items():
+             if freq >= self.min_frequency and self.idx <= self.vocab_size:
+                 self.tokens = []
+                 self.suffixes = []
+                 self.tokens_bf = []
+                 self.suffixes_bf = []
+                 self.tokens.append(word[0])
+                 self.suffixes.append(word[len(word) - 1])
+                 self.split_prefix(word, self.roots_unoptimized)
+                 if len(self.tokens) > 1:
+                     self.split_suffix(word[::-1], self.infls)
+                     self.suffixes = [word[::-1] for word in self.suffixes][::-1]
+                     self.tokenization_to_print += str(self.tokens) + '\t' + str(self.tokens_bf) + '\t' + str(
+                         self.suffixes) + '\t' + str(self.suffixes_bf) + '\n'  # for debugging only
+                     for i in range(0, len(self.tokens)):  # experiments: use only self.suffixes or self.tokens (prefixes)
+                         if i == 0:
+                             self.last_item_in_trie = self.roots
+                             self.add_items_to_trie(self.tokens[0])  # experiments: use only self.suffixes or self.tokens (prefixes)
+                         else:
+                             self.last_item_in_trie = self.roots['++']
+                             self.add_items_to_trie(self.tokens[i])  # experiments: use only self.suffixes or self.tokens (prefixes)
+                         if 'IDX' not in self.last_item_in_trie:
+                             self.last_item_in_trie['IDX'] = self.idx
+                             self.idx += 1
+                 else:
+                     self.last_item_in_trie = self.roots
+                     self.add_items_to_trie(word)
+                     if 'IDX' not in self.last_item_in_trie:
+                         self.last_item_in_trie['IDX'] = self.idx
+                         self.idx += 1
+
+         self.build_vocab_lookup()
+
+     def build_vocab_lookup(self):
+         self.vocab_to_id = {}
+
+         def traverse(trie, path):
+             for k, v in trie.items():
+                 if k == 'IDX':
+                     token = ''.join(path)
+                     self.vocab_to_id[token] = v
+                 elif isinstance(v, dict):
+                     traverse(v, path + [k])
+
+         traverse(self.roots, [])
+
+     def encode(self, sentence: str):
+         self.tokenized_words = []
+         words = sentence.strip().split()
+         token_ids = []
+         for word in words:
+             if word in self.roots['[RSX]']:
+                 token_ids.append(self.roots['[RSX]'][word]['IDX'])
+             else:
+                 self.tokenized_word_longest = ""
+                 self.tokenized_word_idx_longest = None
+                 self.retrieve(word, self.roots)
+                 if self.tokenized_word_idx_longest is not None:
+                     token_ids.append(self.tokenized_word_idx_longest)
+                 else:
+                     token_ids.append(self.roots['[RSX]']['<unk>']['IDX'])
+         return token_ids
+
+     def decode(self, sentence_idxs):
+         # Returns a list of surface tokens, not a joined string
+         tokens = []
+         for idx in sentence_idxs:
+             keys_path = find_idx_path(self.roots, idx)
+             if keys_path:
+                 token = "".join(keys_path)
+                 if token.startswith('[RSX]'):
+                     token = token[5:]
+                 tokens.append(token)
+         return tokens
+
+     def retrieve(self, word, trie):
+         self.longest_match_in_trie(word, trie)
+         if self.tokenized_word_longest:
+             self.tokenized_words.append([self.tokenized_word_longest, self.tokenized_word_idx_longest])
+         else:
+             self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])
+
+     def longest_match_in_trie(self, string, trie):
+         if string[0] in trie:
+             self.tokenized_word_longest += string[0]
+             if 'IDX' in trie[string[0]]:
+                 self.tokenized_word_idx_longest = trie[string[0]]['IDX']
+             if len(string) > 1:
+                 self.longest_match_in_trie(string[1:], trie[string[0]])
+         else:
+             # print(string[0], self.tokenized_word_longest)
+             if string[0] in self.roots['++'] and self.tokenized_word_idx_longest:
+                 self.tokenized_words.append([self.tokenized_word_longest + '++', self.tokenized_word_idx_longest])
+                 self.tokenized_word_longest = '++'
+                 self.tokenized_word_idx_longest = 0
+                 self.longest_match_in_trie(string, self.roots['++'])
+             else:
+                 self.tokenized_words.append(['<unk>', self.roots['[RSX]']['<unk>']['IDX']])
+                 self.tokenized_word_longest = None
+
+     def split_prefix(self, word, trie):
+         l = len(word)
+         if l > 1:
+             self.get_pair_in_trie(word[0], word[1], trie)
+             if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:
+                 self.tokens.append(word[1])
+                 self.tokens_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
+             else:
+                 self.tokens[len(self.tokens) - 1] = self.tokens[len(self.tokens) - 1] + word[1]
+             if l > 2:
+                 self.split_prefix(word[1:], trie[word[0]])
+
+     def split_suffix(self, word, trie):
+         l = len(word)
+         if l > 1:
+             self.get_pair_in_trie(word[0], word[1], trie)
+             if self.check_tp(self.n_prefix, self.n_suffix) and self.get_bf(trie[word[0]]) <= self.bf:  # verify if the Tolerance Principle applies
+                 self.suffixes.append(word[1])
+                 self.suffixes_bf.append(word[0] + str(self.get_bf(trie[word[0]])))
+             else:
+                 self.suffixes[len(self.suffixes) - 1] = self.suffixes[len(self.suffixes) - 1] + word[1]
+             if l > 2:
+                 if word[0] in trie.keys():
+                     self.split_suffix(word[1:], trie[word[0]])
+
+     def get_pair_in_trie(self, prefix, suffix, trie):
+         self.n_prefix = 0
+         self.n_suffix = 0
+         if prefix in trie:
+             if suffix in trie[prefix]:
+                 self.n_prefix = trie[prefix]["##"]
+                 self.n_suffix = trie[prefix][suffix]["##"]
+
+     def check_tp(self, m, d):  # verify if the Tolerance Principle applies between m(other) and d(aughter) nodes
+         if not m > 1:
+             return False
+         else:
+             tp = m / log(m)
+             # chained comparison: cutoff <= m, m != d, and d > tp must all hold
+             if self.cutoff <= m != d > tp:
+                 return True
+             else:
+                 return False
+
+     def get_bf(self, m):  # return the branching factor of the mother node
+         keys = m.keys()
+         n_keys = len(keys)
+         for k in keys:
+             if k in self.special_tokens:
+                 n_keys -= 1
+         return n_keys
+
+     def add_items_to_trie(self, items):
+         for item in items:
+             self.add_item_to_trie(item)
+
+     def add_item_to_trie(self, item):
+         if item not in self.last_item_in_trie:
+             self.last_item_in_trie[item] = {}
+         self.last_item_in_trie = self.last_item_in_trie[item]
+
+     @staticmethod
+     def pad_sentence(sentence, l):
+         """
+         Pads the given sentence with "[pad]" tokens at the beginning to reach the desired length.
+
+         Parameters:
+         - sentence (str): The original sentence to be padded.
+         - l (int): The desired total number of tokens in the sentence after padding.
+
+         Returns:
+         - str: The padded sentence.
+         """
+         words = sentence.split()
+         n_pad = max(l - len(words), 0)  # Ensure n_pad is not negative
+         pad_tokens = ["[pad]"] * n_pad
+         padded_sentence = ' '.join(pad_tokens + words)
+         return padded_sentence
+
+     def get_num_chars_in_trie(self):
+         return self.num_chars_in_trie
+
+     def get_num_chars_in_corpus(self):
+         return self.num_chars_in_corpus
+
+     def get_vocab_size(self) -> int:
+         return self.idx
+
+     def get_vocab(self):
+         return self.vocab_to_id.copy()
+
+     def get_num_tokens_in_corpus(self):
+         return self.num_tokens_in_corpus
+
+     def get_num_types_in_corpus(self):
+         return len(self.types)
+
+     def get_compression_ratio(self):
+         return round(self.num_chars_in_trie / self.num_chars_in_corpus, 3)
+
+     def get_ttr(self):
+         return round(len(self.types) / self.num_tokens_in_corpus, 3)
+
+     def save(self, save_file):
+         self.build_vocab_lookup()
+         with open(save_file, 'w') as f:
+             json.dump({
+                 'roots': self.roots,
+                 'vocab': self.vocab_to_id
+             }, f, indent=2)
+
+     def from_pretrained(self, load_file):
+         with open(load_file + '/tokenizer.json', 'r') as f:
+             data = json.load(f)
+
+         # Backward compatibility: in the old format, data is just the roots trie
+         if isinstance(data, dict) and 'roots' in data:
+             self.roots = data['roots']
+             self.vocab_to_id = data.get('vocab', {})  # fall back to an empty dict if missing
+         else:
+             # Old format support (e.g., tokenizer.json only had roots)
+             self.roots = data
+             self.vocab_to_id = {}
+
+         # Ensure [RSX] exists
+         if '[RSX]' not in self.roots:
+             raise ValueError("Invalid tokenizer format: Missing [RSX] root node.")
+
+     def save_types(self, file):
+         with open(file, 'w') as f:
+             json.dump(self.types, f, indent=2)
+
+
+ def sort_trie_by_freq(d):
+     if not isinstance(d, dict):
+         return d
+     # Sort the dictionary items by the value of the nested key '##'
+     sorted_items = sorted(
+         d.items(),
+         key=lambda item: item[1].get('##', float('-inf')) if isinstance(item[1], dict) else float('-inf'),
+         reverse=True
+     )
+     # Clear the dictionary and update with the sorted items
+     d.clear()
+     for k, v in sorted_items:
+         d[k] = sort_trie_by_freq(v)
+     return d
+
+
+ def find_idx_path(d, target_value, path=None):
+     if path is None:
+         path = []
+     for key, value in d.items():
+         if key == 'IDX' and value == target_value:
+             return path
+         elif isinstance(value, dict):
+             result = find_idx_path(value, target_value, path + [key])
+             if result is not None:
+                 return result
+     return None
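
A toy end-to-end run of the trainer and codec above. This is a sketch, not a recipe for the released vocabulary: the corpus is synthetic and tiny, so most words will simply be stored whole in the root trie rather than split at Tolerance-Principle boundaries.

```python
from tokenizer_MorPiece import MorPiece

# Synthetic corpus; repetition gives the tries stable frequency counts.
corpus = "walk walks walked walking talk talks talked talking " * 4
mp = MorPiece(vocab_size=200, min_frequency=2)
mp.train(corpus)

ids = mp.encode("walks talked")
print(ids)                  # ids assigned during trie optimization
print(mp.decode(ids))       # surface tokens recovered via find_idx_path
print(mp.get_vocab_size())  # number of ids handed out (includes special tokens)
```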
tokenizer_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "tokenizer_class": "MorPieceTokenizer",
+   "unk_token": "<unk>",
+   "pad_token": "<pad>",
+   "bos_token": "<s>",
+   "eos_token": "</s>"
+ }