|
|
|
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from i3_model import i3Model, ChunkTokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class I3Config(PretrainedConfig):
    """Configuration for the i3 causal language model.

    Declares the backbone hyperparameters explicitly, using the same
    fallback defaults that ``I3ForCausalLM.__init__`` applies via
    ``getattr``, so a saved ``config.json`` always round-trips to an
    identical model.  ``vocab_size`` has no sensible universal default
    and therefore defaults to ``None``; checkpoints are expected to
    provide it.

    Args:
        vocab_size: Size of the token vocabulary (required for a usable model).
        d_model: Hidden/embedding dimension of the backbone.
        n_heads: Number of attention heads.
        max_seq_len: Maximum sequence length the backbone supports.
        d_state: State dimension of the backbone's recurrent/state component.
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``bos_token_id``).
    """

    model_type = "i3"

    def __init__(
        self,
        vocab_size=None,
        d_model=512,
        n_heads=16,
        max_seq_len=256,
        d_state=32,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.d_state = d_state
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class I3ForCausalLM(PreTrainedModel):
    """Hugging-Face-compatible causal-LM wrapper around the custom ``i3Model``.

    Exposes the backbone through the standard ``PreTrainedModel`` surface
    (``config_class``, ``forward`` returning a logits/loss dict, and a
    local-folder ``from_pretrained``).
    """

    config_class = I3Config
    base_model_prefix = "i3"

    def __init__(self, config):
        super().__init__(config)
        # Hyperparameters fall back to the backbone's training defaults when
        # the loaded config predates a field being added.
        self.i3 = i3Model(
            vocab_size=config.vocab_size,
            d_model=getattr(config, "d_model", 512),
            n_heads=getattr(config, "n_heads", 16),
            max_seq_len=getattr(config, "max_seq_len", 256),
            d_state=getattr(config, "d_state", 32),
        )
        # Populated by from_pretrained() when a chunk-vocab file ships
        # alongside the weights; None otherwise.
        self.tokenizer = None
        self.post_init()

    def forward(self, input_ids, labels=None):
        """Run the backbone.

        Args:
            input_ids: Token id tensor passed straight to ``i3Model``.
            labels: Optional target ids; when given, the backbone also
                computes a loss.

        Returns:
            dict with ``"logits"`` and, only when the backbone produced
            one, ``"loss"``.
        """
        logits, loss = self.i3(input_ids, targets=labels)
        output = {"logits": logits}
        if loss is not None:
            output["loss"] = loss
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load model weights and config from HF repo or local folder.

        Also loads chunk tokenizer if present.

        NOTE(review): despite the docstring's mention of HF repos, only a
        local directory path is actually supported, and ``*args``/``**kwargs``
        are accepted for signature compatibility but ignored.

        Raises:
            FileNotFoundError: If ``config.json`` or a weights file is missing.
            ImportError: If a ``.safetensors`` file is present but the
                ``safetensors`` package is not installed.
        """
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Cannot find config.json at {config_path}")
        with open(config_path, "r") as f:
            config_dict = json.load(f)

        config = I3Config(**config_dict)
        model = cls(config)

        cls._load_weights(model, pretrained_model_name_or_path)
        cls._attach_chunk_tokenizer(model, pretrained_model_name_or_path)
        return model

    @staticmethod
    def _load_weights(model, folder):
        """Load a state dict from model.safetensors (preferred) or pytorch_model.bin."""
        bin_path = os.path.join(folder, "pytorch_model.bin")
        safe_path = os.path.join(folder, "model.safetensors")

        if os.path.exists(safe_path):
            # Keep the try body to the import only, so a genuine loading
            # error is never masked by the "install safetensors" message.
            try:
                from safetensors.torch import load_file
            except ImportError as err:
                raise ImportError("Please install safetensors to load .safetensors files") from err
            state_dict = load_file(safe_path)
        elif os.path.exists(bin_path):
            # weights_only=True prevents torch.load from unpickling arbitrary
            # objects out of an untrusted checkpoint; plain tensors are all
            # a state dict should contain.
            state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
        else:
            raise FileNotFoundError("No model file found in the provided path")

        model.load_state_dict(state_dict, strict=True)

    @staticmethod
    def _attach_chunk_tokenizer(model, folder):
        """Attach the chunk tokenizer when its vocab file ships with the weights."""
        vocab_path = os.path.join(folder, "chunk_vocab_combined.json")
        if os.path.exists(vocab_path):
            tokenizer = ChunkTokenizer()
            tokenizer.load(vocab_path)
            model.tokenizer = tokenizer
|
|
|