"""
LM-eval harness wrapper for Circuit/Mirrored transformers.

Usage:
    # Single model
    python -m circuits.bench --checkpoint circuits/checkpoints/mirrored/best.pt --gpu 0

    # Compare all architectures
    python -m circuits.bench --compare --gpu 0
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.instance import Instance

from .config import CircuitConfig
from .model import CircuitTransformer
from .mirrored import MirroredConfig, MirroredTransformer
from .graft_g2lu import load_g2lu_model
from .layers import build_word_start_table, compute_word_positions
from .data import get_tokenizer


# (checkpoint key suffix, model key suffix) pairs tried by _migrate_state_dict.
# A checkpoint key is renamed only when the renamed key is one the model
# expects and the checkpoint does not already provide.
_FFN_KEY_RENAMES = (
    (".ffn.gate_expand.weight", ".ffn.w3.weight"),
    (".ffn.gate_compress.weight", ".ffn.w4.weight"),
)


def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Rename checkpoint keys so they match ``model``'s parameter names.

    Applied fix-ups:

    * the ``_orig_mod.`` prefix that ``torch.compile`` adds to every key is
      stripped;
    * FFN gate weights stored as ``gate_expand`` / ``gate_compress`` are
      renamed to ``w3`` / ``w4`` when the model expects those names and the
      checkpoint lacks them.

    Keys that still do not line up afterwards are only reported here; a strict
    ``load_state_dict`` will then reject them.

    Args:
        state_dict: checkpoint state dict (not mutated).
        model: module whose ``state_dict()`` keys are the target names.

    Returns:
        A state dict with renamed keys (``state_dict`` itself, minus the
        compile prefix, when nothing else needed changing).
    """
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict)
    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys
    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    migrated = dict(state_dict)
    migrations = []
    for key in sorted(unexpected):
        for old_suffix, new_suffix in _FFN_KEY_RENAMES:
            if old_suffix not in key:
                continue
            new_key = key.replace(old_suffix, new_suffix)
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                migrations.append(f" {key} → {new_key}")
            break

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for m in migrations:
            print(m)

    # Report what is still mismatched; strict loading will raise on these,
    # so nothing is silently left at its random initialization.
    migrated_keys = set(migrated)
    still_missing = model_keys - migrated_keys
    if still_missing:
        print(f" Parameters missing from checkpoint: {len(still_missing)}")
        for k in sorted(still_missing):
            print(f" {k}")
    still_unexpected = migrated_keys - model_keys
    if still_unexpected:
        print(f" Checkpoint keys unknown to the model: {len(still_unexpected)}")
        for k in sorted(still_unexpected):
            print(f" {k}")

    return migrated


def _build_model(checkpoint: dict, checkpoint_path: str, device: str):
    """Instantiate the model described by an already-loaded checkpoint dict.

    Args:
        checkpoint: dict returned by ``torch.load`` on ``checkpoint_path``
            (not mutated).
        checkpoint_path: path of that checkpoint; the G²LU graft loader
            reads the file itself.
        device: device the model is moved to.

    Returns:
        ``(model, config, arch_name, model_type)``; ``model`` is in eval mode.
    """
    model_type = checkpoint.get("model_type", "standard")

    if model_type == "graft_g2lu":
        model = load_g2lu_model(checkpoint_path, device=device)
        model.eval()
        n_layers = len(model.g2lu_mlps)
        arch_name = f"G²LU Graft ({checkpoint['pretrained_name']}, {n_layers}L)"
        config = model.model.config  # HF config
        return model, config, arch_name, model_type

    if model_type == "mirrored":
        config_dict = dict(checkpoint["config"])
        if config_dict.get("dual_gate_middle"):
            config_dict.pop("dual_gate_middle")
        config = MirroredConfig.from_dict(config_dict)
        model = MirroredTransformer(config)
        arch_name = f"Mirrored ({model.total_virtual_layers}L)"
    else:
        config = CircuitConfig.from_dict(checkpoint["config"])
        model = CircuitTransformer(config)
        arch_name = f"Standard ({config.num_layers}L)"

    # Strips the torch.compile prefix and renames legacy FFN keys.
    state_dict = _migrate_state_dict(checkpoint["model"], model)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    return model, config, arch_name, model_type


def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load any circuit model from checkpoint with auto-detection.

    The architecture is chosen from the checkpoint's ``model_type`` entry
    (``"graft_g2lu"``, ``"mirrored"``, otherwise standard).

    NOTE: ``weights_only=False`` unpickles arbitrary objects — only load
    checkpoints from trusted sources.

    Returns:
        ``(model, config, arch_name, model_type)``; ``model`` is in eval mode.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    return _build_model(checkpoint, checkpoint_path, device)


class CircuitLM(LM):
    """LM-eval wrapper for Circuit transformer family.

    Requests are processed one at a time; ``batch_size`` is stored and
    reported but not used for batching.
    """

    def __init__(
        self,
        checkpoint: str,
        device: str = "cuda",
        batch_size: int = 1,
        compile: bool = False,
    ):
        super().__init__()
        # Load the checkpoint once and use it for both the model and the
        # tokenizer name. weights_only=False unpickles arbitrary objects:
        # only evaluate trusted checkpoints.
        ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
        tokenizer_name = ckpt.get("tokenizer_name", "gpt2")
        self.model, self.config, self.arch_name, self.model_type = _build_model(
            ckpt, checkpoint, device
        )
        del ckpt

        # Keep raw reference for .generate() — torch.compile only wraps forward()
        self._raw_model = self.model
        if compile:
            self.model = torch.compile(self.model)
            print(" torch.compile: enabled")

        self.tokenizer = get_tokenizer(tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._device = device
        self._batch_size = batch_size

        # Build word-position table if model uses SemRoPE
        self._word_start_table = None
        word_rope_dims = self._config_value("word_rope_dims", 0) or 0
        if word_rope_dims > 0:
            self._word_start_table = build_word_start_table(
                self.tokenizer, len(self.tokenizer)
            ).to(device)
            print(f" Word-position RoPE: {word_rope_dims} dims")

        # Count parameters
        n_params = sum(p.numel() for p in self._raw_model.parameters())
        print(f" Architecture: {self.arch_name}")
        print(f" Parameters: {n_params / 1e6:.1f}M")

    def _config_value(self, name: str, default=None):
        """Read ``name`` from the model config, which may be an object or a dict."""
        if isinstance(self.config, dict):
            return self.config.get(name, default)
        return getattr(self.config, name, default)

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._config_value("max_seq_len") or self._config_value(
            "max_position_embeddings", 512
        )

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _model_call(self, input_ids: torch.Tensor):
        """Run a forward pass and return the ``"logits"`` entry of the output.

        Runs under ``inference_mode`` with bf16 autocast unless the device is
        ``"cpu"``, so the returned logits are typically bf16.
        """
        with torch.inference_mode(), torch.autocast(
            "cuda", dtype=torch.bfloat16, enabled=self._device != "cpu"
        ):
            word_positions = None
            if self._word_start_table is not None:
                word_positions = compute_word_positions(input_ids, self._word_start_table)
            output = self.model(input_ids, use_cache=False, word_positions=word_positions)
            return output["logits"]

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score token-level continuations.

        Args:
            requests: iterable of ``(context_enc, continuation_enc)`` pairs of
                token-id lists.
            disable_tqdm: accepted for interface compatibility; unused.

        Returns:
            One ``(sum_logprob, is_greedy)`` tuple per request.
        """
        results = []
        for context_enc, continuation_enc in requests:
            if not continuation_enc:
                # Nothing to score: log P(empty) = 0, trivially greedy.
                results.append((0.0, True))
                continue
            if not context_enc:
                # The first continuation token needs something to condition
                # on; use EOT as a start token, like lm-eval's reference LMs.
                context_enc = [self.eot_token_id]

            # Keep at most max_length + 1 tokens (truncating from the left).
            # The model sees all but the last; output i predicts token i + 1,
            # so the final n_cont outputs predict the continuation.
            window = (context_enc + continuation_enc)[-(self.max_length + 1):]
            n_cont = min(len(continuation_enc), len(window) - 1)
            input_ids = torch.tensor(
                [window[:-1]], dtype=torch.long, device=self._device
            )
            cont_tokens = torch.tensor(
                [window[-n_cont:]], dtype=torch.long, device=self._device
            )

            logits = self._model_call(input_ids)
            # Upcast before log_softmax: bf16 log-probs lose precision that
            # matters when summed and compared across answer choices.
            cont_logits = logits[:, -n_cont:, :].float()
            log_probs = F.log_softmax(cont_logits, dim=-1)
            token_log_probs = log_probs.gather(2, cont_tokens.unsqueeze(-1)).squeeze(-1)

            total_log_prob = token_log_probs.sum().item()
            is_greedy = bool((cont_logits.argmax(dim=-1) == cont_tokens).all().item())
            results.append((total_log_prob, is_greedy))
        return results

    def _encode_pair(self, context: str, continuation: str):
        """Tokenize a (context, continuation) pair the way lm-eval does.

        Trailing whitespace of the context is moved onto the continuation,
        the joint string is encoded once, and it is split at the length of
        the context's own encoding. The context tokens are then an exact
        prefix of the joint encoding. This matters for tokenizers that
        tokenize differently at string boundaries, e.g. sentencepiece's
        leading ▁.
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        n_ctx = len(self.tok_encode(context))
        whole_enc = self.tok_encode(context + continuation)
        context_enc, continuation_enc = whole_enc[:n_ctx], whole_enc[n_ctx:]
        if not continuation_enc:
            # Edge case: continuation was absorbed into context tokens.
            # Fall back to encoding both halves separately.
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation)
        return context_enc, continuation_enc

    def loglikelihood(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[tuple]:
        """Return ``(sum_logprob, is_greedy)`` for each ``(context, continuation)``."""
        results = []
        for request in tqdm(requests, desc="loglikelihood", disable=disable_tqdm):
            context, continuation = request.args
            pair = self._encode_pair(context, continuation)
            results.extend(self._loglikelihood_tokens([pair]))
        return results

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Return the total log-probability of each document.

        EOT is prepended so the document's first token is scored too. Every
        token is scored exactly once. Each window conditions on as many
        preceding tokens as fit in ``max_length``, including tokens from the
        previous window.
        """
        results = []
        max_len = self.max_length
        for request in tqdm(
            requests, desc="loglikelihood_rolling", disable=disable_tqdm
        ):
            text = request.args[0]
            tokens = [self.eot_token_id] + self.tok_encode(text)

            total_log_prob = 0.0
            start = 1  # index of the first token not yet scored
            while start < len(tokens):
                end = min(start + max_len, len(tokens))
                n_pred = end - start
                # Model input: up to max_len tokens ending right before
                # tokens[end - 1]; its last n_pred outputs predict
                # tokens[start:end].
                window = tokens[max(0, end - 1 - max_len) : end - 1]
                input_ids = torch.tensor(
                    [window], dtype=torch.long, device=self._device
                )
                targets = torch.tensor(
                    [tokens[start:end]], dtype=torch.long, device=self._device
                )

                logits = self._model_call(input_ids)[:, -n_pred:, :].float()
                log_probs = F.log_softmax(logits, dim=-1)
                total_log_prob += log_probs.gather(
                    2, targets.unsqueeze(-1)
                ).sum().item()
                start = end
            results.append(total_log_prob)
        return results

    @staticmethod
    def _unpack_generation_request(request):
        """Return ``(context, gen_kwargs)`` for a generate_until request.

        lm-eval stores generation kwargs as the second element of
        ``request.args``; a ``kwargs`` attribute is accepted as a fallback.
        """
        args = request.args
        context = args[0]
        if len(args) > 1 and isinstance(args[1], dict):
            gen_kwargs = args[1]
        else:
            gen_kwargs = getattr(request, "kwargs", None)
        return context, dict(gen_kwargs or {})

    def _greedy_generate(self, context_enc: List[int], max_gen: int, until: List[str]) -> List[int]:
        """Greedy-decode up to ``max_gen`` tokens after ``context_enc`` (no KV cache).

        The model input is a sliding window of the last ``max_length`` tokens.
        Generation stops early on EOT (which is included in the returned
        list) or once any stop string appears in the decoded output.
        """
        window = torch.tensor([context_enc], dtype=torch.long, device=self._device)
        new_tokens: List[int] = []
        with torch.no_grad():
            for _ in range(max_gen):
                if window.shape[1] > self.max_length:
                    window = window[:, -self.max_length :]
                logits = self._model_call(window)
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                token_id = next_token.item()
                new_tokens.append(token_id)
                if token_id == self.eot_token_id:
                    break
                window = torch.cat([window, next_token], dim=1)
                current_text = self.tok_decode(new_tokens)
                if any(s in current_text for s in until):
                    break
        return new_tokens

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Greedily generate a continuation for each context.

        Honors the task's ``until`` stop strings and ``max_gen_toks``. Other
        generation kwargs, such as sampling settings, are ignored. Output is
        cut at EOT and at the earliest stop string.
        """
        results = []
        for request in tqdm(requests, desc="generate_until", disable=disable_tqdm):
            context, gen_kwargs = self._unpack_generation_request(request)
            until = gen_kwargs.get("until", [self.tokenizer.eos_token])
            if isinstance(until, str):
                until = [until]
            # Drop None (tokenizer without eos) and "" (matches everything).
            until = [s for s in (until or []) if s]
            max_gen = gen_kwargs.get("max_gen_toks", self.max_gen_toks)

            context_enc = self.tok_encode(context) or [self.eot_token_id]
            # Leave room for max_gen new tokens when possible; otherwise keep
            # a full window and let the generation loop slide it.
            max_ctx = self.max_length - max_gen
            if max_ctx <= 0:
                max_ctx = self.max_length
            context_enc = context_enc[-max_ctx:]

            if self.model_type == "graft_g2lu":
                # Use HF's native generate with KV caching — much faster than
                # manual token-by-token without cache (O(n) vs O(n²))
                input_ids = torch.tensor(
                    [context_enc], dtype=torch.long, device=self._device
                )
                with torch.no_grad():
                    output_ids = self._raw_model.generate(
                        input_ids,
                        max_new_tokens=max_gen,
                        do_sample=False,
                        use_cache=True,
                    )
                new_tokens = output_ids[0, input_ids.shape[1] :].tolist()
            else:
                new_tokens = self._greedy_generate(context_enc, max_gen, until)

            # Never return the EOT token's text as part of the generation.
            if self.eot_token_id in new_tokens:
                new_tokens = new_tokens[: new_tokens.index(self.eot_token_id)]
            generated_text = self.tok_decode(new_tokens)

            for stop in until:
                if stop in generated_text:
                    generated_text = generated_text[: generated_text.index(stop)]
            results.append(generated_text)
        return results