# modeling_i3.py
"""HuggingFace ``transformers`` wrapper around the custom i3 model.

Exposes ``I3Config`` (a thin ``PretrainedConfig`` subclass) and
``I3ForCausalLM`` (a ``PreTrainedModel`` that delegates to ``i3Model``),
plus loading of an optional chunk tokenizer vocabulary file.
"""
import os
import json

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from i3_model import i3Model, ChunkTokenizer


# ======================================================================
# I3 Configuration for Transformers
# ======================================================================
class I3Config(PretrainedConfig):
    """Configuration holder for the i3 architecture.

    All hyperparameters (``vocab_size``, ``d_model``, ``n_heads``,
    ``max_seq_len``, ``d_state``) are passed through ``**kwargs`` and
    stored by :class:`PretrainedConfig`; ``I3ForCausalLM.__init__``
    reads them back with defaults via ``getattr``.
    """

    model_type = "i3"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


# ======================================================================
# I3 For Causal Language Modeling (HuggingFace Wrapper)
# ======================================================================
class I3ForCausalLM(PreTrainedModel):
    """Causal-LM wrapper delegating computation to :class:`i3Model`."""

    config_class = I3Config
    base_model_prefix = "i3"

    def __init__(self, config):
        """Build the underlying ``i3Model`` from ``config``.

        ``config.vocab_size`` is required; the remaining hyperparameters
        fall back to the defaults below when absent from the config.
        """
        super().__init__(config)
        self.i3 = i3Model(
            vocab_size=config.vocab_size,
            d_model=getattr(config, "d_model", 512),
            n_heads=getattr(config, "n_heads", 16),
            max_seq_len=getattr(config, "max_seq_len", 256),
            d_state=getattr(config, "d_state", 32),
        )
        # Tokenizer reference (optional, for convenience); populated by
        # ``from_pretrained`` when a chunk vocab file is present.
        self.tokenizer = None
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the i3 model.

        Args:
            input_ids: Token id tensor forwarded to ``i3Model``.
            labels: Optional targets; when given, ``i3Model`` is expected
                to return a loss alongside the logits.
            **kwargs: Ignored. Absorbs extra arguments that HF ``Trainer``
                may pass (e.g. ``attention_mask``) so they do not raise
                ``TypeError``.

        Returns:
            dict with ``"logits"`` and, when a loss was computed, ``"loss"``.
        """
        logits, loss = self.i3(input_ids, targets=labels)
        output = {"logits": logits}
        if loss is not None:
            output["loss"] = loss
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load model weights, config, and optional tokenizer.

        NOTE: only a **local directory** is supported — despite the name,
        this implementation does not download remote Hub repo ids.

        Expects in the directory:
            - ``config.json`` (required)
            - ``model.safetensors`` or ``pytorch_model.bin`` (one required;
              safetensors is preferred when both exist)
            - ``chunk_vocab_combined.json`` (optional chunk tokenizer vocab)

        Raises:
            FileNotFoundError: If config or weight files are missing.
            ImportError: If a ``.safetensors`` file is present but the
                ``safetensors`` package is not installed.
        """
        # Load config.json
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Cannot find config.json at {config_path}")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        config = I3Config(**config_dict)
        model = cls(config)

        # Load model weights (prefer safetensors when available).
        bin_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        safe_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")

        if os.path.exists(safe_path):
            # Keep the try body minimal: only the import can raise ImportError.
            try:
                import safetensors.torch
            except ImportError as e:
                raise ImportError(
                    "Please install safetensors to load .safetensors files"
                ) from e
            state_dict = safetensors.torch.load_file(safe_path)
            model.load_state_dict(state_dict, strict=True)
        elif os.path.exists(bin_path):
            # weights_only=True restricts pickle to tensor payloads,
            # preventing arbitrary code execution from untrusted checkpoints.
            state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict, strict=True)
        else:
            raise FileNotFoundError("No model file found in the provided path")

        # Load tokenizer if chunk_vocab_combined.json exists
        vocab_path = os.path.join(
            pretrained_model_name_or_path, "chunk_vocab_combined.json"
        )
        if os.path.exists(vocab_path):
            tokenizer = ChunkTokenizer()
            tokenizer.load(vocab_path)
            model.tokenizer = tokenizer

        return model