Text Generation
Transformers
PyTorch
Safetensors
English
i3
i3-architecture
hybrid-model
rwkv-mamba
custom_code
i3-80m / modeling_i3.py
FlameF0X's picture
Create modeling_i3.py
420490d verified
# modeling_i3.py
import os
import json
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from i3_model import i3Model, ChunkTokenizer
# ======================================================================
# I3 Configuration for Transformers
# ======================================================================
class I3Config(PretrainedConfig):
    """Configuration for the i3 hybrid (RWKV/Mamba-style) causal LM.

    Declares the hyperparameters that ``I3ForCausalLM`` reads when building
    the underlying ``i3Model``. Defaults mirror the fallbacks hard-coded via
    ``getattr`` in the wrapper, so a config.json that omits a field still
    loads; values present in config.json (passed as kwargs) always win.
    """

    model_type = "i3"

    def __init__(
        self,
        vocab_size=32000,  # NOTE(review): fallback only — confirm against shipped config.json
        d_model=512,
        n_heads=16,
        max_seq_len=256,
        d_state=32,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.d_state = d_state
        super().__init__(**kwargs)
# ======================================================================
# I3 For Causal Language Modeling (HuggingFace Wrapper)
# ======================================================================
class I3ForCausalLM(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the custom ``i3Model``.

    Bridges the project-local model to the transformers API: builds the
    backbone from an ``I3Config``, exposes a Trainer-compatible ``forward``,
    and provides a local-directory ``from_pretrained`` that also picks up the
    model's chunk tokenizer when its vocab file ships with the checkpoint.
    """

    config_class = I3Config
    base_model_prefix = "i3"

    def __init__(self, config):
        super().__init__(config)
        # getattr fallbacks keep configs that predate these fields loadable.
        self.i3 = i3Model(
            vocab_size=config.vocab_size,
            d_model=getattr(config, "d_model", 512),
            n_heads=getattr(config, "n_heads", 16),
            max_seq_len=getattr(config, "max_seq_len", 256),
            d_state=getattr(config, "d_state", 32),
        )
        # Optional ChunkTokenizer; attached by from_pretrained() when
        # chunk_vocab_combined.json is present next to the weights.
        self.tokenizer = None
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        """Run a forward pass through the i3 backbone.

        Args:
            input_ids: token-id tensor (presumably (batch, seq) — exact shape
                is defined by ``i3Model``; TODO confirm).
            labels: optional target ids; when given, ``i3Model`` also returns
                a loss.
            **kwargs: extra arguments (e.g. ``attention_mask``) that HF
                pipelines/Trainer may pass; accepted and ignored because
                ``i3Model`` does not consume them.

        Returns:
            dict with ``"logits"`` and, when ``labels`` is provided and the
            backbone computed one, ``"loss"``.
        """
        logits, loss = self.i3(input_ids, targets=labels)
        output = {"logits": logits}
        if loss is not None:
            output["loss"] = loss
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load model weights and config from a local folder.

        Also loads the chunk tokenizer if present.

        NOTE(review): unlike the stock HF implementation this only reads a
        local directory — it does not download from the Hub.

        Raises:
            FileNotFoundError: if config.json or a weights file is missing.
            ImportError: if a .safetensors file exists but safetensors is
                not installed.
        """
        # Load config.json
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Cannot find config.json at {config_path}")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = I3Config(**config_dict)
        model = cls(config)

        # Prefer safetensors (no pickle involved) over the .bin checkpoint.
        bin_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        safe_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
        if os.path.exists(safe_path):
            # Keep the try body minimal: only the import can raise ImportError.
            try:
                import safetensors.torch
            except ImportError:
                raise ImportError("Please install safetensors to load .safetensors files")
            state_dict = safetensors.torch.load_file(safe_path)
            model.load_state_dict(state_dict, strict=True)
        elif os.path.exists(bin_path):
            # weights_only=True stops torch.load from unpickling arbitrary
            # objects out of an untrusted checkpoint (supported torch >= 1.13).
            state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict, strict=True)
        else:
            raise FileNotFoundError("No model file found in the provided path")

        # Load tokenizer if chunk_vocab_combined.json exists
        vocab_path = os.path.join(pretrained_model_name_or_path, "chunk_vocab_combined.json")
        if os.path.exists(vocab_path):
            tokenizer = ChunkTokenizer()
            tokenizer.load(vocab_path)
            model.tokenizer = tokenizer
        return model