# load_stentor2.py — loader for the StentorLabs/Stentor2-12M-Preview model.
# Provenance: StentorLabs, Hugging Face commit ac7b750 (verified).
from transformers import LlamaForCausalLM, LlamaConfig
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import torch
MODEL_ID = "StentorLabs/Stentor2-12M-Preview"
# ── Minimal TokenMonster adapter ───────────────────────────────────────────────
class _Tokenizer:
    """Minimal TokenMonster adapter exposing a small HF-tokenizer-like surface.

    Provides ``encode``/``decode``/``__len__`` plus the ``bos``/``eos``/``pad``
    token attributes that downstream generation code expects.
    """

    def __init__(self, vocab_path):
        """Load a TokenMonster vocab binary from *vocab_path*.

        Installs the ``tokenmonster`` package on the fly only if it is not
        already importable (the original unconditionally ran ``pip install``
        on every instantiation, which is slow and breaks offline use).
        """
        try:
            import tokenmonster
        except ImportError:
            # Fallback: install into the current interpreter, then retry.
            import subprocess, sys
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "tokenmonster", "-q"],
                check=True,
            )
            import tokenmonster
        self._vocab = tokenmonster.load(vocab_path)
        # Register special tokens so they round-trip through encode/decode.
        self._vocab.add_special_token("</s>")
        self._vocab.add_special_token("<s>")
        self.eos_token = "</s>"
        self.bos_token = "<s>"
        self.pad_token = "</s>"  # no dedicated pad token; reuse EOS
        self.eos_token_id = int(self._vocab.token_to_id("</s>"))
        self.bos_token_id = int(self._vocab.token_to_id("<s>"))
        self.pad_token_id = self.eos_token_id

    def encode(self, text):
        """Returns a plain list of int token ids."""
        ids = self._vocab.tokenize(text)
        # tokenize() may return a numpy array; normalize to a Python list.
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        return [int(x) for x in ids]

    def decode(self, token_ids):
        """Accepts a list or 1-D tensor of int token ids."""
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        return self._vocab.decode(token_ids)

    def __len__(self):
        # Vocabulary size, as expected by model-config sanity checks.
        return len(self._vocab)
# ── Main loader ────────────────────────────────────────────────────────────────
def load_stentor2(model_id=MODEL_ID, dtype=torch.float32):
    """Download and assemble the Stentor2 model plus its TokenMonster tokenizer.

    Args:
        model_id: Hugging Face repo id to pull the vocab, config, and weights from.
        dtype: torch dtype the model is cast to after loading.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    # Tokenizer — fetch the TokenMonster vocab binary directly from the repo.
    vocab_path = hf_hub_download(repo_id=model_id, filename="tokenmonster.vocab")
    tokenizer = _Tokenizer(vocab_path)

    # Model skeleton from the repo config, weights loaded manually below.
    config = LlamaConfig.from_pretrained(model_id)
    model = LlamaForCausalLM(config)
    weight_path = hf_hub_download(repo_id=model_id, filename="model.safetensors")
    raw_sd = load_file(weight_path)

    # INT8 QAT checkpoints store full-precision copies under ".weight_master";
    # remap those to ".weight" and drop any quantized ".weight" they shadow.
    suffix = "_master"
    master_keys = {key for key in raw_sd if key.endswith(".weight" + suffix)}
    shadowed = {key[: -len(suffix)] for key in master_keys}
    sd = {}
    for key, tensor in raw_sd.items():
        if key in master_keys:
            sd[key[: -len(suffix)]] = tensor
        elif key not in shadowed:
            sd[key] = tensor

    # strict=False: remapped checkpoints may omit buffers the config implies.
    model.load_state_dict(sd, strict=False)
    model.to(dtype).eval()
    return model, tokenizer