""" HuggingFace Integration for EMG Model and MorPiece Tokenizer This file makes your custom model and tokenizer compatible with HuggingFace and lm_eval """ import json import os from typing import List, Optional, Union, Dict, Any import torch import torch.nn as nn from transformers import ( PreTrainedModel, PretrainedConfig, PreTrainedTokenizer, AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM, GenerationMixin, # Add this import ) from transformers.modeling_outputs import CausalLMOutputWithPast # Import your existing classes from model_eMG_simplified import EMGLanguageModel, EMGConfig, OptimizedEMG, OptimizedEMGCell from tokenizer_MorPiece import MorPiece class MorPieceTokenizer(PreTrainedTokenizer): """ HuggingFace compatible wrapper for MorPiece tokenizer """ def __init__(self, vocab_file=None, model_file=None, unk_token="", pad_token="", bos_token="", eos_token="", **kwargs): # Initialize the MorPiece tokenizer self.morpiece = MorPiece() # Load from file if provided if vocab_file or model_file: model_path = vocab_file or model_file if os.path.isdir(model_path): self.morpiece.from_pretrained(model_path) else: # Load from JSON file with open(model_path, 'r') as f: data = json.load(f) self.morpiece.roots = data.get('roots', data) if 'vocab' in data: self.morpiece.vocab_to_id = data['vocab'] else: self.morpiece.build_vocab_lookup() # Get vocabulary self.vocab = self.morpiece.get_vocab() # Set special tokens super().__init__( unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token, **kwargs ) @property def vocab_size(self): return len(self.vocab) def get_vocab(self): return self.vocab.copy() def _tokenize(self, text: str) -> List[str]: """Tokenize text into tokens""" # For HuggingFace compatibility, we need to return string tokens token_ids = self.morpiece.encode(text) tokens = self.morpiece.decode(token_ids) return tokens def _convert_token_to_id(self, token: str) -> int: """Convert token to ID""" return self.vocab.get(token, self.vocab.get(self.unk_token, 0)) def _convert_id_to_token(self, index: int) -> str: """Convert ID to token""" for token, idx in self.vocab.items(): if idx == index: return token return self.unk_token def convert_tokens_to_string(self, tokens: List[str]) -> str: """Convert tokens back to string""" # Handle special tokens text = "".join(tokens) # Clean up special tokens for display for special_token in [self.pad_token, self.bos_token, self.eos_token]: if special_token: text = text.replace(special_token, "") return text.strip() def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]: """Encode text to token IDs""" if add_special_tokens and self.bos_token: text = f"{self.bos_token} {text}" if add_special_tokens and self.eos_token: text = f"{text} {self.eos_token}" return self.morpiece.encode(text) def decode(self, token_ids: List[int], skip_special_tokens: bool = True, **kwargs) -> str: """Decode token IDs to text""" tokens = [] for token_id in token_ids: token = self._convert_id_to_token(token_id) if skip_special_tokens and token in [self.pad_token, self.bos_token, self.eos_token, self.unk_token]: continue tokens.append(token) return self.convert_tokens_to_string(tokens) def save_pretrained(self, save_directory: str, **kwargs): """Save tokenizer""" os.makedirs(save_directory, exist_ok=True) # Save MorPiece data tokenizer_file = os.path.join(save_directory, "tokenizer.json") self.morpiece.save(tokenizer_file) # Save tokenizer config config = { "tokenizer_class": "MorPieceTokenizer", "unk_token": 
class EMGForCausalLM(EMGLanguageModel, GenerationMixin):
    """
    EMG model with a HuggingFace-compatible causal-LM interface for lm_eval.

    Inheriting from GenerationMixin explicitly makes generate() available
    without the transformers deprecation warning.
    """

    def __init__(self, config):
        # EMGLanguageModel handles all module setup and stores the config;
        # GenerationMixin carries no state of its own to initialize.
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Forward pass with a HuggingFace-compatible output format."""
        # Embed tokens, run the EMG layers, and project to vocabulary logits
        embedded = self.embedding(input_ids)
        output, hidden = self.emg(embedded, past_key_values)
        logits = self.output_projection(output)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=hidden if use_cache else None,
            # hidden_states is conventionally a tuple of per-layer tensors
            hidden_states=(output,),
        )

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, past_key_values=None, attention_mask=None, **kwargs
    ):
        """Prepare inputs for generation."""
        # With a cache available, only the newest token needs to be fed in
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the cache along the batch dimension for beam search."""
        if past_key_values is None:
            return None
        reordered_cache = []
        for layer_cache in past_key_values:
            if isinstance(layer_cache, tuple):
                reordered_cache.append(tuple(
                    cache.index_select(0, beam_idx) for cache in layer_cache
                ))
            else:
                reordered_cache.append(layer_cache.index_select(0, beam_idx))
        return tuple(reordered_cache)


def register_emg_model():
    """Register the EMG model and MorPiece tokenizer with transformers."""
    # Register config, model classes, and tokenizer
    AutoConfig.register("emg", EMGConfig)
    AutoModel.register(EMGConfig, EMGLanguageModel)
    AutoModelForCausalLM.register(EMGConfig, EMGForCausalLM)
    AutoTokenizer.register(EMGConfig, slow_tokenizer_class=MorPieceTokenizer)
    print("EMG model and MorPiece tokenizer registered with transformers!")


def load_emg_model_and_tokenizer(model_path: str):
    """
    Load the EMG model and MorPiece tokenizer from a saved directory.

    Args:
        model_path: Path to the saved model directory.

    Returns:
        tuple: (model, tokenizer)
    """
    # Register the custom classes first
    register_emg_model()

    # Load the model and tokenizer
    config = EMGConfig.from_pretrained(model_path)
    model = EMGForCausalLM.from_pretrained(model_path, config=config)
    tokenizer = MorPieceTokenizer.from_pretrained(model_path)

    # Set the pad token id in the model config if it is missing
    if getattr(config, "pad_token_id", None) is None:
        config.pad_token_id = tokenizer.pad_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer


def 
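# Hedged sketch of wiring this into lm_eval. Recent lm-evaluation-harness
# releases let HFLM wrap an in-memory model/tokenizer pair; the keyword names
# (`pretrained`, `tokenizer`) and the task name below are assumptions to
# check against the installed lm_eval version.
def _demo_lm_eval(model_path: str, task: str = "lambada_openai"):
    """Evaluate the EMG model on a single lm_eval task (illustration only)."""
    import lm_eval
    from lm_eval.models.huggingface import HFLM

    model, tokenizer = load_emg_model_and_tokenizer(model_path)
    lm = HFLM(pretrained=model, tokenizer=tokenizer)
    results = lm_eval.simple_evaluate(model=lm, tasks=[task])
    return results["results"]

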
test_model_and_tokenizer(model_path: str):
    """Test the loaded model and tokenizer."""
    model, tokenizer = load_emg_model_and_tokenizer(model_path)

    # Test encoding/decoding
    test_text = "Hello world, this is a test."
    print(f"Original text: {test_text}")

    # Encode
    encoded = tokenizer.encode(test_text)
    print(f"Encoded: {encoded}")

    # Decode
    decoded = tokenizer.decode(encoded, skip_special_tokens=True)
    print(f"Decoded: {decoded}")

    # Test the model forward pass
    input_ids = torch.tensor([encoded])
    with torch.no_grad():
        outputs = model(input_ids)

    print(f"Model output shape: {outputs.logits.shape}")
    print(f"Model output type: {type(outputs)}")
    print("Model and tokenizer are working correctly!")

    return model, tokenizer


if __name__ == "__main__":
    # Example usage
    model_path = "path/to/your/saved/model"  # Replace with your model path

    # Register the classes
    register_emg_model()

    # Test loading
    try:
        model, tokenizer = test_model_and_tokenizer(model_path)
        print("✅ Model and tokenizer loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
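

# Hedged extra smoke test, kept below the __main__ guard so it is available
# on import: greedy decoding through GenerationMixin.generate(). The
# generation arguments used here (max_new_tokens, do_sample, pad_token_id)
# are standard transformers keywords, not EMG-specific.
def generate_sample(model, tokenizer, prompt: str, max_new_tokens: int = 20) -> str:
    """Run a short greedy generation as a sanity check (illustration only)."""
    input_ids = torch.tensor([tokenizer.encode(prompt)])
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)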