Prisma / lm_eval_wrapper.py
y3i12's picture
Initial commit
56e82ec
"""
LM-eval harness wrapper for Circuit/Mirrored transformers.
Usage:
# Single model
python -m circuits.bench --checkpoint circuits/checkpoints/mirrored/best.pt --gpu 0
# Compare all architectures
python -m circuits.bench --compare --gpu 0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from tqdm import tqdm
from lm_eval.api.model import LM
from lm_eval.api.instance import Instance
from .config import CircuitConfig
from .model import CircuitTransformer
from .mirrored import MirroredConfig, MirroredTransformer
from .graft_g2lu import load_g2lu_model
from .layers import build_word_start_table, compute_word_positions
from .data import get_tokenizer
def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Adapt a checkpoint ``state_dict`` to the key names ``model`` expects.

    Two adjustments are applied:

    * a leading ``_orig_mod.`` (added by ``torch.compile``) is stripped from
      every key when at least one key carries it;
    * checkpoint keys containing ``.ffn.gate_expand.weight`` /
      ``.ffn.gate_compress.weight`` are renamed to ``.ffn.w3.weight`` /
      ``.ffn.w4.weight`` when the checkpoint key is unknown to the model and
      the renamed key is one the model is missing.

    Args:
        state_dict: Mapping of parameter name to value, as stored in the
            checkpoint. It is never mutated.
        model: Object whose ``state_dict()`` defines the expected key names.

    Returns:
        ``state_dict`` itself (prefix-stripped if needed) when the keys match
        the model exactly, otherwise a new dict with the renames applied.
        Keys that remain missing or unexpected are left for the caller to
        deal with; missing ones are listed on stdout.
    """
    # torch.compile wraps the module, which prefixes every parameter name.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys
    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    # (fragment found in the checkpoint key, fragment the model uses instead)
    renames = (
        (".ffn.gate_expand.weight", ".ffn.w3.weight"),
        (".ffn.gate_compress.weight", ".ffn.w4.weight"),
    )

    migrated = dict(state_dict)
    migrations = []
    # Sorted so the log output is deterministic across runs.
    for key in sorted(unexpected):
        for old_fragment, new_fragment in renames:
            if old_fragment not in key:
                continue
            new_key = key.replace(old_fragment, new_fragment)
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                migrations.append(f" {key} -> {new_key}")
                # The key has been moved; a second pop would raise KeyError.
                break

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for line in migrations:
            print(line)

    # Whatever is still missing after the renames is absent from the checkpoint.
    if missing:
        print(f" New parameters (freshly initialized): {len(missing)}")
        for k in sorted(missing):
            print(f" {k}")
    return migrated
def load_model(checkpoint_path: str, device: str = "cuda"):
    """Build the model stored in a checkpoint, dispatching on its ``model_type``.

    The checkpoint's ``model_type`` entry (``"standard"`` when absent) selects
    between the G²LU graft loader, ``MirroredTransformer`` and
    ``CircuitTransformer``.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        device: Device the returned model is placed on.

    Returns:
        A ``(model, config, arch_name, model_type)`` tuple; the model is in
        eval mode.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # this must only be pointed at trusted checkpoint files.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    kind = ckpt.get("model_type", "standard")

    # Grafted models come from their own loader, which takes the path itself.
    if kind == "graft_g2lu":
        grafted = load_g2lu_model(checkpoint_path, device=device)
        grafted.eval()
        layer_count = len(grafted.g2lu_mlps)
        label = f"G²LU Graft ({ckpt['pretrained_name']}, {layer_count}L)"
        hf_config = grafted.model.config  # HF config
        return grafted, hf_config, label, kind

    if kind == "mirrored":
        raw_config = ckpt["config"]
        # A truthy "dual_gate_middle" entry is removed before building the
        # config; a falsy or absent one is left as it is.
        if raw_config.get("dual_gate_middle"):
            raw_config.pop("dual_gate_middle")
        config = MirroredConfig.from_dict(raw_config)
        model = MirroredTransformer(config)
        arch_name = f"Mirrored ({model.total_virtual_layers}L)"
    else:
        config = CircuitConfig.from_dict(ckpt["config"])
        model = CircuitTransformer(config)
        arch_name = f"Standard ({config.num_layers}L)"

    # _migrate_state_dict strips the torch.compile "_orig_mod." prefix and
    # renames legacy keys.
    # NOTE(review): load_state_dict is called with its default arguments here;
    # confirm that is intended given that the migration step reports leftover
    # missing keys as "freshly initialized".
    weights = _migrate_state_dict(ckpt["model"], model)
    model.load_state_dict(weights)
    model = model.to(device).eval()
    return model, config, arch_name, kind
class CircuitLM(LM):
    """LM-eval wrapper for Circuit transformer family.

    Loads a checkpoint through ``load_model`` and exposes the three request
    types of the lm-eval ``LM`` interface (``loglikelihood``,
    ``loglikelihood_rolling``, ``generate_until``). Requests are processed one
    at a time; ``batch_size`` is only stored and reported.
    """

    def __init__(
        self,
        checkpoint: str,
        device: str = "cuda",
        batch_size: int = 1,
        compile: bool = False,
    ):
        """Load the model and tokenizer named by ``checkpoint``.

        Args:
            checkpoint: Path to the checkpoint file.
            device: Device the model and all input tensors are placed on.
            batch_size: Value reported by the ``batch_size`` property.
            compile: When true, wrap the model with ``torch.compile``.
        """
        super().__init__()
        self.model, self.config, self.arch_name, self.model_type = load_model(
            checkpoint, device
        )
        # Keep raw reference for .generate() — torch.compile only wraps forward()
        self._raw_model = self.model
        if compile:
            self.model = torch.compile(self.model)
            print(" torch.compile: enabled")

        # load_model does not hand back the checkpoint dict, so it is read a
        # second time just for the tokenizer name and released right away.
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # use trusted checkpoints.
        _ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
        _tok_name = _ckpt.get("tokenizer_name", "gpt2")
        del _ckpt
        self.tokenizer = get_tokenizer(_tok_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._device = device
        self._batch_size = batch_size

        # Build word-position table if model uses SemRoPE
        self._word_start_table = None
        word_rope_dims = getattr(self.config, "word_rope_dims", 0)
        if not word_rope_dims and isinstance(self.config, dict):
            word_rope_dims = self.config.get("word_rope_dims", 0)
        # A config may carry an explicit None; treat it as "feature off".
        word_rope_dims = word_rope_dims or 0
        if word_rope_dims > 0:
            self._word_start_table = build_word_start_table(
                self.tokenizer, len(self.tokenizer)
            ).to(device)
            print(f" Word-position RoPE: {word_rope_dims} dims")

        # Count parameters
        n_params = sum(p.numel() for p in self.model.parameters())
        print(f" Architecture: {self.arch_name}")
        print(f" Parameters: {n_params / 1e6:.1f}M")

    @property
    def eot_token_id(self):
        """Token id of the tokenizer's EOS token."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Context length: ``max_seq_len``, else ``max_position_embeddings``, else 512."""
        return getattr(self.config, "max_seq_len", None) or getattr(
            self.config, "max_position_embeddings", 512
        )

    @property
    def max_gen_toks(self):
        """Default number of tokens to generate when a request gives none."""
        return 256

    @property
    def batch_size(self):
        """Batch size passed to the constructor."""
        return self._batch_size

    @property
    def device(self):
        """Device passed to the constructor."""
        return self._device

    def tok_encode(self, string: str) -> List[int]:
        """Tokenize ``string`` without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens: List[int]) -> str:
        """Turn a list of token ids back into text."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, input_ids: torch.Tensor):
        """Run one forward pass without KV cache and return ``output["logits"]``.

        Runs under ``torch.inference_mode`` and, unless the device is
        ``"cpu"``, under CUDA bfloat16 autocast. Word positions are computed
        and passed along only when a word-start table was built.
        """
        with torch.inference_mode(), torch.autocast(
            "cuda", dtype=torch.bfloat16, enabled=self._device != "cpu"
        ):
            word_positions = None
            if self._word_start_table is not None:
                word_positions = compute_word_positions(
                    input_ids, self._word_start_table
                )
            output = self.model(
                input_ids, use_cache=False, word_positions=word_positions
            )
        return output["logits"]

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score tokenized ``(context, continuation)`` pairs.

        Args:
            requests: Iterable of ``(context_enc, continuation_enc)`` token-id
                lists.
            disable_tqdm: Accepted for interface compatibility; unused.

        Returns:
            One ``(total_log_prob, is_greedy)`` tuple per request, where
            ``is_greedy`` tells whether every scored continuation token is the
            argmax prediction.
        """
        results = []
        max_len = self.max_length
        for context_enc, continuation_enc in requests:
            context_enc = list(context_enc)
            continuation_enc = list(continuation_enc)

            # Every scored token needs at least one token to its left, so an
            # empty context is replaced by the EOT token when one exists.
            if not context_enc and self.eot_token_id is not None:
                context_enc = [self.eot_token_id]

            full_enc = context_enc + continuation_enc
            if not full_enc:
                # Nothing to condition on and nothing to score.
                results.append((0.0, True))
                continue

            # Truncate from the left if too long. If that eats the whole
            # context, the first surviving token serves as conditioning and
            # is itself not scored.
            overflow = max(len(full_enc) - max_len, 0)
            full_enc = full_enc[overflow:]
            ctx_len = max(len(context_enc) - overflow, 1)

            input_ids = torch.tensor(
                [full_enc], dtype=torch.long, device=self._device
            )
            logits = self._model_call(input_ids)

            # The logit at position i predicts token i + 1.
            cont_logits = logits[:, ctx_len - 1 : -1, :]
            cont_tokens = input_ids[:, ctx_len:]

            # Upcast: autocast may have produced bfloat16 logits, which is too
            # coarse for summing log-probabilities.
            log_probs = F.log_softmax(cont_logits.float(), dim=-1)
            token_log_probs = log_probs.gather(
                2, cont_tokens.unsqueeze(-1)
            ).squeeze(-1)
            total_log_prob = token_log_probs.sum().item()
            is_greedy = (cont_logits.argmax(dim=-1) == cont_tokens).all().item()
            results.append((total_log_prob, is_greedy))
        return results

    def loglikelihood(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[tuple]:
        """Score ``(context, continuation)`` string pairs from ``request.args``.

        Returns:
            One ``(log_prob, is_greedy)`` tuple per request, in order.
        """
        results = []
        for request in tqdm(
            requests, desc="loglikelihood", disable=disable_tqdm
        ):
            context, continuation = request.args
            # Encode full text together to get correct tokenization,
            # then split — sentencepiece tokenizes differently at string
            # boundaries vs mid-sequence (the leading ▁ problem)
            context_enc = self.tok_encode(context)
            full_enc = self.tok_encode(context + continuation)
            continuation_enc = full_enc[len(context_enc):]
            if not continuation_enc:
                # Edge case: continuation was absorbed into context tokens
                # Fall back to encoding continuation separately
                continuation_enc = self.tok_encode(continuation)
            result = self._loglikelihood_tokens([(context_enc, continuation_enc)])
            results.append(result[0])
        return results

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Compute the log-likelihood of each full text in ``request.args[0]``.

        The token sequence is prefixed with the EOT token (when the tokenizer
        has one) so the first token is scored too. Long texts are split into
        windows of at most ``max_length`` tokens; consecutive windows share
        one token so that every token is scored exactly once. Without the
        shared token, the first token of each window would be skipped.

        Returns:
            One summed log-probability per request, in order.
        """
        results = []
        max_len = self.max_length
        # Each window predicts `stride` tokens from `stride + 1` inputs.
        stride = max(max_len - 1, 1)
        prefix = [] if self.eot_token_id is None else [self.eot_token_id]

        for request in tqdm(
            requests, desc="loglikelihood_rolling", disable=disable_tqdm
        ):
            text = request.args[0]
            tokens = prefix + self.tok_encode(text)
            total_log_prob = 0.0

            # `start` is the index of the conditioning token of each window;
            # the loop ends once no token to the right remains unscored.
            for start in range(0, len(tokens) - 1, stride):
                window = tokens[start : start + stride + 1]
                input_ids = torch.tensor(
                    [window], dtype=torch.long, device=self._device
                )
                logits = self._model_call(input_ids)
                shift_logits = logits[:, :-1, :]
                shift_labels = input_ids[:, 1:]
                # Upcast: autocast may have produced bfloat16 logits.
                log_probs = F.log_softmax(shift_logits.float(), dim=-1)
                token_log_probs = log_probs.gather(
                    2, shift_labels.unsqueeze(-1)
                ).squeeze(-1)
                total_log_prob += token_log_probs.sum().item()
            results.append(total_log_prob)
        return results

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Greedily generate a continuation for each request.

        The context is ``request.args[0]``. Generation options (``until``,
        ``max_gen_toks``) are read from ``request.args[1]`` when it is a dict,
        otherwise from a ``kwargs`` attribute if the request has one.

        Returns:
            One generated string per request, cut at the earliest stop
            sequence. The EOS token string always counts as a stop sequence.
        """
        results = []
        for request in tqdm(
            requests, desc="generate_until", disable=disable_tqdm
        ):
            args = request.args
            context = args[0]

            # The lm-eval interface documents generate_until arguments as
            # (context, gen_kwargs).
            gen_kwargs = None
            if len(args) > 1 and isinstance(args[1], dict):
                gen_kwargs = args[1]
            if gen_kwargs is None:
                gen_kwargs = getattr(request, "kwargs", None) or {}

            until = gen_kwargs.get("until")
            if until is None:
                until = []
            elif isinstance(until, str):
                until = [until]
            # Drop None/empty entries: "" is a substring of every text and
            # would stop generation immediately.
            until = [s for s in until if s]
            eos_text = self.tokenizer.eos_token
            if eos_text and eos_text not in until:
                until.append(eos_text)

            max_gen = gen_kwargs.get("max_gen_toks")
            if max_gen is None:
                max_gen = self.max_gen_toks

            context_enc = self.tok_encode(context)
            # Truncate context from left if needed. Always keep at least one
            # token, even when max_gen is not smaller than max_length.
            max_ctx = max(self.max_length - max_gen, 1)
            context_enc = context_enc[-max_ctx:]
            # The model needs at least one input token.
            if not context_enc and self.eot_token_id is not None:
                context_enc = [self.eot_token_id]

            input_ids = torch.tensor(
                [context_enc], dtype=torch.long, device=self._device
            )

            if self.model_type == "graft_g2lu":
                # Use HF's native generate with KV caching — much faster than
                # manual token-by-token without cache (O(n) vs O(n²))
                with torch.no_grad():
                    output_ids = self._raw_model.generate(
                        input_ids,
                        max_new_tokens=max_gen,
                        do_sample=False,
                        use_cache=True,
                    )
                generated_text = self.tok_decode(
                    output_ids[0, input_ids.shape[1] :].tolist()
                )
            else:
                # Generated ids are collected separately from the model
                # input: the input is a sliding window that drops old tokens,
                # so offsets into it do not stay valid.
                new_tokens = []
                window = input_ids
                with torch.no_grad():
                    for _ in range(max_gen):
                        # Truncate if we exceed max_length
                        if window.shape[1] > self.max_length:
                            window = window[:, -self.max_length :]
                        logits = self._model_call(window)
                        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                        token_id = next_token.item()
                        if token_id == self.eot_token_id:
                            break
                        new_tokens.append(token_id)
                        window = torch.cat([window, next_token], dim=1)
                        current_text = self.tok_decode(new_tokens)
                        if any(stop in current_text for stop in until):
                            break
                generated_text = self.tok_decode(new_tokens)

            # Successive truncation leaves the text before the earliest stop.
            for stop in until:
                if stop in generated_text:
                    generated_text = generated_text[: generated_text.index(stop)]
            results.append(generated_text)
        return results