"""Run lm-eval-harness on the packaged TaoNet-mini-T2 checkpoint. This is a lightweight adapter around the custom TaoTrain checkpoint format. By default it uses the fast full-sequence SSM path for benchmark scoring: ssm_finite_tail_correction = True ssm_kernel_mode = conv """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any ROOT = Path(__file__).resolve().parent TAOTRAIN_SRC = ROOT / "code" / "TaoTrain" / "src" SSM_SRC = ROOT / "code" / "Taotern_SSM" for path in (TAOTRAIN_SRC, SSM_SRC): if str(path) not in sys.path: sys.path.insert(0, str(path)) import torch import torch.nn.functional as F from tqdm import tqdm from lm_eval import evaluator from lm_eval.api.model import LM from taoTrain.checkpointing.checkpoint import CheckpointManager from taoTrain.config import ModelConfig from taoTrain.inference.inferencer import Inferencer from taoTrain.models import get_model def apply_ssm_overrides(model: torch.nn.Module, *, kernel_mode: str, finite_tail: bool) -> int: count = 0 for module in model.modules(): changed = False if hasattr(module, "kernel_mode"): module.kernel_mode = kernel_mode changed = True if hasattr(module, "finite_tail_correction"): module.finite_tail_correction = finite_tail changed = True clear = getattr(module, "clear_kernel_cache", None) if callable(clear): clear() if changed: count += 1 return count class TaoTrainLM(LM): def __init__( self, checkpoint: str, tokenizer: str, device: str = "cuda", dtype: str = "bfloat16", max_length: int = 512, ssm_kernel_mode: str = "conv", finite_tail: bool = True, eval_batch_size: int = 8, ) -> None: super().__init__() self._device = torch.device(device if device == "cpu" or torch.cuda.is_available() else "cpu") self.dtype = { "float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16, }[dtype] self.max_length = max_length self.ssm_kernel_mode = ssm_kernel_mode self.finite_tail = finite_tail self.eval_batch_size = eval_batch_size self.checkpoint_path = Path(checkpoint) self.tokenizer_path = Path(tokenizer) checkpoint_obj = CheckpointManager(self.checkpoint_path.parent).load( self.checkpoint_path, device=self._device, ) config_dict = checkpoint_obj.get("config", {}) model_config_dict = dict(config_dict.get("model", {})) model_config_dict["ssm_finite_tail_correction"] = finite_tail model_config_dict["ssm_kernel_mode"] = ssm_kernel_mode self.tokenizer = Inferencer._load_tokenizer(self.tokenizer_path) self.model = get_model(ModelConfig(**model_config_dict), device=self._device) self.model.load_state_dict(checkpoint_obj["model_state"], strict=False) self.model.to(self._device) self.model.eval() self.override_count = apply_ssm_overrides( self.model, kernel_mode=ssm_kernel_mode, finite_tail=finite_tail, ) @property def tokenizer_name(self) -> str: return str(self.tokenizer_path) def tok_encode(self, text: str) -> list[int]: return list(self.tokenizer.encode(text)) def tok_decode(self, tokens: list[int]) -> str: return self.tokenizer.decode(tokens, skip_special_tokens=True) def _prepare_ids(self, full_ids: list[int], continuation_start: int) -> tuple[list[int], int]: bos = getattr(self.tokenizer, "bos_token_id", None) if len(full_ids) < 2: if bos is None: return full_ids, continuation_start full_ids = [bos] + full_ids continuation_start += 1 if len(full_ids) > self.max_length: drop = len(full_ids) - self.max_length full_ids = full_ids[drop:] continuation_start -= drop continuation_start = max(1, continuation_start) return full_ids, continuation_start def _score_batch(self, batch: list[tuple[list[int], int]]) -> list[tuple[float, bool]]: if not batch: return [] pad_id = getattr(self.tokenizer, "pad_token_id", 0) if pad_id is None or pad_id < 0: pad_id = 0 max_len = max(len(ids) for ids, _ in batch) if max_len < 2: return [(0.0, True) for _ in batch] rows = [ids + [pad_id] * (max_len - len(ids)) for ids, _ in batch] input_ids = torch.tensor([row[:-1] for row in rows], dtype=torch.long, device=self._device) target_ids = torch.tensor([row[1:] for row in rows], dtype=torch.long, device=self._device) device_type = "cuda" if self._device.type == "cuda" else "cpu" autocast_enabled = self._device.type == "cuda" and self.dtype in {torch.float16, torch.bfloat16} with torch.inference_mode(), torch.autocast( device_type=device_type, dtype=self.dtype, enabled=autocast_enabled, ): apply_ssm_overrides( self.model, kernel_mode=self.ssm_kernel_mode, finite_tail=self.finite_tail, ) outputs = self.model( input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None, ) log_probs = F.log_softmax(outputs["logits"].float(), dim=-1) results = [] for row_idx, (ids, continuation_start) in enumerate(batch): continuation_start = max(1, continuation_start) if continuation_start >= len(ids): results.append((0.0, True)) continue score_start = continuation_start - 1 score_end = len(ids) - 1 row_log_probs = log_probs[row_idx, score_start:score_end] row_targets = target_ids[row_idx, score_start:score_end] token_log_probs = row_log_probs.gather(1, row_targets.unsqueeze(1)).squeeze(1) greedy = row_log_probs.argmax(dim=-1) is_greedy = bool(torch.equal(greedy, row_targets)) results.append((float(token_log_probs.sum().item()), is_greedy)) return results def _score_tokens(self, full_ids: list[int], continuation_start: int) -> tuple[float, bool]: full_ids, continuation_start = self._prepare_ids(full_ids, continuation_start) if continuation_start >= len(full_ids): return 0.0, True input_ids = torch.tensor(full_ids[:-1], dtype=torch.long, device=self._device).unsqueeze(0) target_ids = torch.tensor(full_ids[1:], dtype=torch.long, device=self._device) score_start = continuation_start - 1 device_type = "cuda" if self._device.type == "cuda" else "cpu" autocast_enabled = self._device.type == "cuda" and self.dtype in {torch.float16, torch.bfloat16} with torch.inference_mode(), torch.autocast( device_type=device_type, dtype=self.dtype, enabled=autocast_enabled, ): apply_ssm_overrides( self.model, kernel_mode=self.ssm_kernel_mode, finite_tail=self.finite_tail, ) outputs = self.model( input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None, ) logits = outputs["logits"][0] log_probs = F.log_softmax(logits.float(), dim=-1) cont_targets = target_ids[score_start:] cont_log_probs = log_probs[score_start : score_start + len(cont_targets)] token_log_probs = cont_log_probs.gather(1, cont_targets.unsqueeze(1)).squeeze(1) greedy = cont_log_probs.argmax(dim=-1) is_greedy = bool(torch.equal(greedy, cont_targets)) return float(token_log_probs.sum().item()), is_greedy def loglikelihood(self, requests: list[Any]) -> list[tuple[float, bool]]: prepared = [] for req in tqdm(requests, desc="loglikelihood"): context, continuation = req.args context_ids = self.tok_encode(context) continuation_ids = self.tok_encode(continuation) full_ids = self.tok_encode(context + continuation) if len(full_ids) < len(continuation_ids): full_ids = context_ids + continuation_ids continuation_start = max(0, len(full_ids) - len(continuation_ids)) prepared.append(self._prepare_ids(full_ids, continuation_start)) results = [] for start in tqdm(range(0, len(prepared), self.eval_batch_size), desc="score_batches"): results.extend(self._score_batch(prepared[start : start + self.eval_batch_size])) return results def loglikelihood_rolling(self, requests: list[Any]) -> list[float]: results = [] for req in tqdm(requests, desc="rolling"): (text,) = req.args ids = self.tok_encode(text) total = 0.0 step = max(1, self.max_length - 1) bos = getattr(self.tokenizer, "bos_token_id", None) prefix = [bos] if bos is not None else [] for start in range(0, len(ids), step): chunk = prefix + ids[start : start + step] score, _ = self._score_tokens(chunk, 1 if prefix else 0) total += score results.append(total) return results def generate_until(self, requests: list[Any]) -> list[str]: from chat_ssm_fixed import generate outputs = [] for req in tqdm(requests, desc="generate_until"): context, gen_kwargs = req.args max_gen_toks = int(gen_kwargs.get("max_gen_toks", 64)) until = gen_kwargs.get("until", []) if isinstance(until, str): until = [until] text = generate( self.model, self.tokenizer, context, device=self._device, dtype=self.dtype, max_new_tokens=max_gen_toks, temperature=0.7, top_p=0.85, repetition_penalty=1.2, greedy=True, ) for stop in until: if stop and stop in text: text = text.split(stop)[0] outputs.append(text) return outputs def metric_value(task_result: dict[str, Any], preferred: list[str]) -> float | None: for key in preferred: if key in task_result: value = task_result[key] if isinstance(value, (int, float)): return float(value) return None def json_safe(obj: Any) -> Any: if isinstance(obj, dict): return {str(k): json_safe(v) for k, v in obj.items()} if isinstance(obj, list): return [json_safe(v) for v in obj] if isinstance(obj, tuple): return [json_safe(v) for v in obj] if isinstance(obj, (str, int, float, bool)) or obj is None: return obj return str(obj) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", default=str(ROOT / "model" / "pretrain_final_model.pt")) parser.add_argument("--tokenizer", default=str(ROOT / "tokenizer" / "tokenizer.model")) parser.add_argument( "--tasks", default="mmlu,hellaswag,arc_easy,arc_challenge,piqa,winogrande", ) parser.add_argument("--num-fewshot", type=int, default=0) parser.add_argument("--limit", type=float, default=None) parser.add_argument("--device", default="cuda") parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16") parser.add_argument("--ssm-kernel-mode", choices=["conv", "recurrent"], default="conv") parser.add_argument("--finite-tail", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--eval-batch-size", type=int, default=8) parser.add_argument("--output", default=str(ROOT / "artifacts" / "lm_eval_results.json")) args = parser.parse_args() model = TaoTrainLM( checkpoint=args.checkpoint, tokenizer=args.tokenizer, device=args.device, dtype=args.dtype, ssm_kernel_mode=args.ssm_kernel_mode, finite_tail=args.finite_tail, eval_batch_size=args.eval_batch_size, ) print(f"device={model.device}") print(f"ssm_overrides={model.override_count}") print(f"ssm_kernel_mode={args.ssm_kernel_mode}") print(f"finite_tail={args.finite_tail}") task_names = [item.strip() for item in args.tasks.split(",") if item.strip()] results = evaluator.simple_evaluate( model=model, tasks=task_names, num_fewshot=args.num_fewshot, limit=args.limit, batch_size=1, log_samples=False, verbosity="INFO", ) if results is None: raise RuntimeError("lm-eval returned no results") output = Path(args.output) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(json.dumps(json_safe(results), indent=2, ensure_ascii=False), encoding="utf-8") preferred = { "mmlu": ["acc,none", "acc"], "hellaswag": ["acc_norm,none", "acc_norm", "acc,none", "acc"], "arc_easy": ["acc_norm,none", "acc_norm", "acc,none", "acc"], "arc_challenge": ["acc_norm,none", "acc_norm", "acc,none", "acc"], "piqa": ["acc_norm,none", "acc_norm", "acc,none", "acc"], "winogrande": ["acc,none", "acc"], } print("\nTask benchmark:") values = [] for task in task_names: task_result = results["results"].get(task, {}) value = metric_value(task_result, preferred.get(task, ["acc_norm,none", "acc,none", "acc_norm", "acc"])) if value is not None: values.append(value) print(f" {task}: {value:.4f}") if values: print(f" mean_primary_score: {sum(values) / len(values):.4f}") print(f" num_fewshot: {args.num_fewshot}") if args.limit is not None: print(f" limit: {args.limit}") if __name__ == "__main__": main()