|
|
""" |
|
|
HuggingFace Integration for EMG Model and MorPiece Tokenizer |
|
|
This file makes your custom model and tokenizer compatible with HuggingFace and lm_eval |
|
|
""" |
|
|
|
|
|
import json |
|
|
import os |
|
|
from typing import List, Optional, Union, Dict, Any |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import ( |
|
|
PreTrainedModel, |
|
|
PretrainedConfig, |
|
|
PreTrainedTokenizer, |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
GenerationMixin, |
|
|
) |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
|
|
|
from model_eMG_simplified import EMGLanguageModel, EMGConfig, OptimizedEMG, OptimizedEMGCell |
|
|
from tokenizer_MorPiece import MorPiece |
|
|
|
|
|
|
|
|
class MorPieceTokenizer(PreTrainedTokenizer):
    """HuggingFace-compatible wrapper for the MorPiece tokenizer."""

    def __init__(self,
                 vocab_file=None,
                 model_file=None,
                 unk_token="<unk>",
                 pad_token="<pad>",
                 bos_token="<s>",
                 eos_token="</s>",
                 **kwargs):
        self.morpiece = MorPiece()

        if vocab_file or model_file:
            model_path = vocab_file or model_file
            if os.path.isdir(model_path):
                self.morpiece.from_pretrained(model_path)
            else:
                with open(model_path, 'r') as f:
                    data = json.load(f)
                self.morpiece.roots = data.get('roots', data)
                if 'vocab' in data:
                    self.morpiece.vocab_to_id = data['vocab']
                else:
                    self.morpiece.build_vocab_lookup()

        # Build the vocabulary and its reverse lookup before calling the
        # superclass __init__, which may invoke get_vocab() and the token/ID
        # converters during initialization.
        self.vocab = self.morpiece.get_vocab()
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs
        )

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return self.vocab.copy()

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text into a list of token strings."""
        # MorPiece encodes straight to IDs, so map the IDs back to token
        # strings to satisfy the List[str] contract of _tokenize.
        token_ids = self.morpiece.encode(text)
        return [self._convert_id_to_token(token_id) for token_id in token_ids]

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID."""
        return self.vocab.get(token, self.vocab.get(self.unk_token, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token via the precomputed reverse lookup."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert tokens back to a string."""
        text = "".join(tokens)
        for special_token in [self.pad_token, self.bos_token, self.eos_token]:
            if special_token:
                text = text.replace(special_token, "")
        return text.strip()

    def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
        """Encode text to token IDs."""
        token_ids = self.morpiece.encode(text)
        # Prepend/append the special-token IDs directly rather than splicing
        # the token strings into the text, where MorPiece would re-tokenize
        # them as ordinary characters.
        if add_special_tokens:
            if self.bos_token:
                token_ids = [self._convert_token_to_id(self.bos_token)] + token_ids
            if self.eos_token:
                token_ids = token_ids + [self._convert_token_to_id(self.eos_token)]
        return token_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **kwargs) -> str:
        """Decode token IDs to text."""
        tokens = []
        for token_id in token_ids:
            token = self._convert_id_to_token(token_id)
            if skip_special_tokens and token in [self.pad_token, self.bos_token, self.eos_token, self.unk_token]:
                continue
            tokens.append(token)
        return self.convert_tokens_to_string(tokens)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the tokenizer model and its config."""
        os.makedirs(save_directory, exist_ok=True)

        tokenizer_file = os.path.join(save_directory, "tokenizer.json")
        self.morpiece.save(tokenizer_file)

        config = {
            "tokenizer_class": "MorPieceTokenizer",
            "unk_token": self.unk_token,
            "pad_token": self.pad_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
        }
        config_file = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the tokenizer from a saved directory or vocab file."""
        return cls(vocab_file=pretrained_model_name_or_path, **kwargs)
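

# A minimal round-trip sketch for the wrapper above; `model_dir` is a
# placeholder and is assumed to hold a tokenizer saved by
# MorPieceTokenizer.save_pretrained().
def _tokenizer_roundtrip_demo(model_dir: str) -> None:
    tok = MorPieceTokenizer.from_pretrained(model_dir)
    ids = tok.encode("hello world", add_special_tokens=True)
    print(ids)                                        # e.g. [bos_id, ..., eos_id]
    print(tok.decode(ids, skip_special_tokens=True))  # special tokens stripped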
|
|
|
|
|
|
|
|
class EMGForCausalLM(EMGLanguageModel, GenerationMixin):
    """
    EMG model with HuggingFace-compatible causal-LM outputs for lm_eval.

    Inheriting from GenerationMixin makes generate() available and silences
    the transformers warning about models lacking the mixin.
    """

    def __init__(self, config):
        EMGLanguageModel.__init__(self, config)
        # GenerationMixin is stateless, so it needs no explicit __init__ call.
        self.config = config

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> CausalLMOutputWithPast:
        """Forward pass with a HuggingFace-compatible output format."""
        embedded = self.embedding(input_ids)
        output, hidden = self.emg(embedded, past_key_values)
        logits = self.output_projection(output)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=hidden if use_cache else None,
            # hidden_states is expected to be a tuple of layer activations.
            hidden_states=(output,),
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values=None,
        attention_mask=None,
        **kwargs
    ):
        """Prepare inputs for a generation step."""
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the cache along the batch dimension for beam search."""
        if past_key_values is None:
            return None

        reordered_cache = []
        for layer_cache in past_key_values:
            if isinstance(layer_cache, tuple):
                reordered_cache.append(tuple(
                    cache.index_select(0, beam_idx) for cache in layer_cache
                ))
            else:
                reordered_cache.append(layer_cache.index_select(0, beam_idx))
        return tuple(reordered_cache)
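

# A minimal sketch of how the cache contract of forward() above can drive
# manual greedy decoding; the model/tokenizer pair is assumed to come from
# load_emg_model_and_tokenizer() below. GenerationMixin's generate() remains
# the canonical entry point.
def _greedy_decode_demo(model, tokenizer, prompt: str, max_new_tokens: int = 20) -> str:
    generated = torch.tensor([tokenizer.encode(prompt)])
    past = None
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Once a cache exists, only the newest token needs to be fed.
            step_ids = generated if past is None else generated[:, -1:]
            out = model(step_ids, past_key_values=past, use_cache=True)
            past = out.past_key_values
            next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_id], dim=-1)
    return tokenizer.decode(generated[0].tolist())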
|
|
|
|
|
|
|
|
|
|
|
def register_emg_model():
    """Register the EMG config, model, and tokenizer with transformers."""
    AutoConfig.register("emg", EMGConfig)
    AutoModel.register(EMGConfig, EMGLanguageModel)
    AutoModelForCausalLM.register(EMGConfig, EMGForCausalLM)
    AutoTokenizer.register(EMGConfig, MorPieceTokenizer)
    print("EMG model and MorPiece tokenizer registered with transformers!")
|
|
|
|
|
|
|
|
def load_emg_model_and_tokenizer(model_path: str):
    """
    Load the EMG model and MorPiece tokenizer from a saved directory.

    Args:
        model_path: Path to the saved model directory

    Returns:
        tuple: (model, tokenizer)
    """
    register_emg_model()

    config = EMGConfig.from_pretrained(model_path)
    model = EMGForCausalLM.from_pretrained(model_path, config=config)
    tokenizer = MorPieceTokenizer.from_pretrained(model_path)

    # Batching and generation need a pad token ID; fall back to the tokenizer's.
    if not hasattr(config, 'pad_token_id') or config.pad_token_id is None:
        config.pad_token_id = tokenizer.pad_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer
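

# Since this module targets lm_eval compatibility, here is a minimal sketch of
# plugging the loaded pair into lm-evaluation-harness. It assumes lm_eval
# >= 0.4, whose HFLM wrapper accepts pre-loaded model/tokenizer objects; exact
# argument names may differ between versions.
def _lm_eval_demo(model_path: str, tasks=("lambada_openai",)):
    import lm_eval
    from lm_eval.models.huggingface import HFLM

    model, tokenizer = load_emg_model_and_tokenizer(model_path)
    lm = HFLM(pretrained=model, tokenizer=tokenizer)
    return lm_eval.simple_evaluate(model=lm, tasks=list(tasks))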
|
|
|
|
|
|
|
|
def test_model_and_tokenizer(model_path: str):
    """Test the loaded model and tokenizer."""
    model, tokenizer = load_emg_model_and_tokenizer(model_path)

    test_text = "Hello world, this is a test."
    print(f"Original text: {test_text}")

    encoded = tokenizer.encode(test_text)
    print(f"Encoded: {encoded}")

    decoded = tokenizer.decode(encoded, skip_special_tokens=True)
    print(f"Decoded: {decoded}")

    input_ids = torch.tensor([encoded])
    with torch.no_grad():
        outputs = model(input_ids)
    print(f"Model output shape: {outputs.logits.shape}")
    print(f"Model output type: {type(outputs)}")

    print("Model and tokenizer are working correctly!")
    return model, tokenizer


if __name__ == "__main__":
    model_path = "path/to/your/saved/model"  # placeholder: point at a real checkpoint

    register_emg_model()

    try:
        model, tokenizer = test_model_and_tokenizer(model_path)
        print("✅ Model and tokenizer loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading model: {e}")