# TaoNet-mini-T2 / eval_lm_eval.py
# Uploaded by StarMist0012 with the upload-large-folder tool ("Add files using
# upload-large-folder tool", commit 388fd6e, verified).
"""Run lm-eval-harness on the packaged TaoNet-mini-T2 checkpoint.
This is a lightweight adapter around the custom TaoTrain checkpoint format.
By default it uses the fast full-sequence SSM path for benchmark scoring:
ssm_finite_tail_correction = True
ssm_kernel_mode = conv
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any
# Repository-relative locations of the vendored project sources.
ROOT = Path(__file__).resolve().parent
TAOTRAIN_SRC = ROOT / "code" / "TaoTrain" / "src"
SSM_SRC = ROOT / "code" / "Taotern_SSM"
# Put each source tree at the front of sys.path (at most once) so the project
# imports that follow resolve from these directories. Because both are inserted
# at index 0 in this order, SSM_SRC ends up ahead of TAOTRAIN_SRC when both are
# newly added.
for path in (TAOTRAIN_SRC, SSM_SRC):
    if str(path) in sys.path:
        continue
    sys.path.insert(0, str(path))
import torch
import torch.nn.functional as F
from tqdm import tqdm
from lm_eval import evaluator
from lm_eval.api.model import LM
from taoTrain.checkpointing.checkpoint import CheckpointManager
from taoTrain.config import ModelConfig
from taoTrain.inference.inferencer import Inferencer
from taoTrain.models import get_model
def apply_ssm_overrides(model: torch.nn.Module, *, kernel_mode: str, finite_tail: bool) -> int:
    """Force SSM kernel settings onto every submodule that exposes them.

    Walks ``model.modules()``. A module with a ``kernel_mode`` attribute gets
    ``kernel_mode`` assigned, and one with ``finite_tail_correction`` gets
    ``finite_tail``. Every module that has a callable ``clear_kernel_cache``
    has it invoked, whether or not any attribute was overwritten.

    Returns:
        Number of modules on which at least one attribute was overwritten.
    """
    touched = 0
    for submodule in model.modules():
        overridden = False
        # Checked and assigned in this order: kernel_mode first, then the tail flag.
        for attr_name, new_value in (
            ("kernel_mode", kernel_mode),
            ("finite_tail_correction", finite_tail),
        ):
            if hasattr(submodule, attr_name):
                setattr(submodule, attr_name, new_value)
                overridden = True
        clear_cache = getattr(submodule, "clear_kernel_cache", None)
        if callable(clear_cache):
            clear_cache()
        touched += int(overridden)
    return touched
class TaoTrainLM(LM):
    """lm-eval-harness ``LM`` adapter around a TaoTrain checkpoint.

    The model and tokenizer are loaded once in ``__init__``. ``loglikelihood`` and
    ``loglikelihood_rolling`` score token sequences in right-padded batches of
    ``eval_batch_size`` (one forward pass per batch); ``generate_until`` delegates
    to ``chat_ssm_fixed.generate``.
    """

    def __init__(
        self,
        checkpoint: str,
        tokenizer: str,
        device: str = "cuda",
        dtype: str = "bfloat16",
        max_length: int = 512,
        ssm_kernel_mode: str = "conv",
        finite_tail: bool = True,
        eval_batch_size: int = 8,
    ) -> None:
        """Load the checkpoint, build the model, and apply the SSM overrides.

        Args:
            checkpoint: Path to the checkpoint file; its parent directory is used as
                the ``CheckpointManager`` root.
            tokenizer: Path handed to ``Inferencer._load_tokenizer``.
            device: Requested torch device. Anything other than ``"cpu"`` falls back
                to CPU when CUDA is unavailable.
            dtype: ``"float32"``, ``"bfloat16"`` or ``"float16"``; used as the
                autocast dtype (autocast is only enabled on CUDA).
            max_length: Maximum tokens per scored sequence (context + continuation);
                longer sequences are truncated from the left.
            ssm_kernel_mode: Written into the model config and onto every module
                that has a ``kernel_mode`` attribute.
            finite_tail: Written into the model config and onto every module that
                has a ``finite_tail_correction`` attribute.
            eval_batch_size: Number of sequences per forward pass when scoring.

        Raises:
            KeyError: If ``dtype`` is not one of the supported names.
        """
        super().__init__()
        self._device = torch.device(device if device == "cpu" or torch.cuda.is_available() else "cpu")
        self.dtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[dtype]
        self.max_length = max_length
        self.ssm_kernel_mode = ssm_kernel_mode
        self.finite_tail = finite_tail
        self.eval_batch_size = eval_batch_size
        self.checkpoint_path = Path(checkpoint)
        self.tokenizer_path = Path(tokenizer)
        checkpoint_obj = CheckpointManager(self.checkpoint_path.parent).load(
            self.checkpoint_path,
            device=self._device,
        )
        config_dict = checkpoint_obj.get("config", {})
        model_config_dict = dict(config_dict.get("model", {}))
        # Constructor arguments override the SSM settings stored in the checkpoint.
        model_config_dict["ssm_finite_tail_correction"] = finite_tail
        model_config_dict["ssm_kernel_mode"] = ssm_kernel_mode
        self.tokenizer = Inferencer._load_tokenizer(self.tokenizer_path)
        self.model = get_model(ModelConfig(**model_config_dict), device=self._device)
        load_result = self.model.load_state_dict(checkpoint_obj["model_state"], strict=False)
        # strict=False tolerates key mismatches; report them instead of silently
        # benchmarking a partially loaded model.
        self._warn_on_key_mismatch(load_result)
        self.model.to(self._device)
        self.model.eval()
        # Also set the attributes directly on the built modules; the number of
        # modules touched is kept for reporting.
        self.override_count = apply_ssm_overrides(
            self.model,
            kernel_mode=ssm_kernel_mode,
            finite_tail=finite_tail,
        )

    @staticmethod
    def _warn_on_key_mismatch(load_result: Any) -> None:
        """Print a stderr warning when a non-strict ``load_state_dict`` skipped keys."""
        missing = list(getattr(load_result, "missing_keys", None) or [])
        unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
        if missing or unexpected:
            print(
                f"warning: checkpoint/model key mismatch "
                f"(missing={len(missing)}, unexpected={len(unexpected)}); "
                f"first missing={missing[:5]}, first unexpected={unexpected[:5]}",
                file=sys.stderr,
            )

    @property
    def device(self) -> torch.device:
        """Torch device the model was placed on.

        The lm-eval ``LM`` base class does not provide this attribute itself.
        """
        return self._device

    @property
    def tokenizer_name(self) -> str:
        """Identifier for the tokenizer (its path as a string)."""
        return str(self.tokenizer_path)

    def tok_encode(self, text: str) -> list[int]:
        """Encode ``text`` to a list of token ids with the loaded tokenizer.

        No BOS is added here; whether ``tokenizer.encode`` adds one itself is up to
        the tokenizer (NOTE(review): confirm it does not).
        """
        return list(self.tokenizer.encode(text))

    def tok_decode(self, tokens: list[int]) -> str:
        """Decode token ids back to text, skipping special tokens."""
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def _bos_id(self) -> int | None:
        """Return the tokenizer's BOS id, or ``None`` when it is missing or negative."""
        bos = getattr(self.tokenizer, "bos_token_id", None)
        if bos is None or bos < 0:
            return None
        return bos

    def _prepare_ids(self, full_ids: list[int], continuation_start: int) -> tuple[list[int], int]:
        """Make a ``(token_ids, continuation_start)`` pair ready for scoring.

        If the first continuation token has nothing before it to be predicted from
        (``continuation_start < 1``), a BOS token is prepended when the tokenizer
        has one. The sequence is then left-truncated to ``max_length`` tokens, which
        can cut into the continuation for very long inputs.

        Returns:
            The possibly modified ids and a ``continuation_start`` that is always at
            least 1. Without a BOS and with an empty context, the first continuation
            token therefore cannot be scored.
        """
        bos = self._bos_id()
        if continuation_start < 1 and bos is not None:
            full_ids = [bos] + full_ids
            continuation_start += 1
        if len(full_ids) > self.max_length:
            drop = len(full_ids) - self.max_length
            full_ids = full_ids[drop:]
            continuation_start -= drop
        return full_ids, max(1, continuation_start)

    def _score_batch(self, batch: list[tuple[list[int], int]]) -> list[tuple[float, bool]]:
        """Score prepared ``(token_ids, continuation_start)`` rows in one forward pass.

        Rows are right-padded to the longest row with the tokenizer's pad id (0 when
        it is missing or negative), and the attention mask is all ones. This relies
        on the model being causal, so that padding after a row's last real token
        does not change that row's scores (NOTE(review): confirm for the SSM conv
        kernel path).

        Returns:
            One ``(sum of continuation log-probs, all continuation tokens greedy)``
            tuple per row, in input order. A row with no continuation tokens yields
            ``(0.0, True)``.
        """
        if not batch:
            return []
        pad_id = getattr(self.tokenizer, "pad_token_id", 0)
        if pad_id is None or pad_id < 0:
            pad_id = 0
        max_len = max(len(ids) for ids, _ in batch)
        if max_len < 2:
            return [(0.0, True) for _ in batch]
        rows = [ids + [pad_id] * (max_len - len(ids)) for ids, _ in batch]
        # Position t of input_ids predicts position t of target_ids (next token).
        input_ids = torch.tensor([row[:-1] for row in rows], dtype=torch.long, device=self._device)
        target_ids = torch.tensor([row[1:] for row in rows], dtype=torch.long, device=self._device)
        device_type = "cuda" if self._device.type == "cuda" else "cpu"
        autocast_enabled = self._device.type == "cuda" and self.dtype in {torch.float16, torch.bfloat16}
        with torch.inference_mode(), torch.autocast(
            device_type=device_type,
            dtype=self.dtype,
            enabled=autocast_enabled,
        ):
            # Re-applied before every forward pass (this also calls each module's
            # clear_kernel_cache). NOTE(review): presumably guards against
            # generate_until switching modes; confirm whether this is still needed.
            apply_ssm_overrides(
                self.model,
                kernel_mode=self.ssm_kernel_mode,
                finite_tail=self.finite_tail,
            )
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                labels=None,
            )
        log_probs = F.log_softmax(outputs["logits"].float(), dim=-1)
        results = []
        for row_idx, (ids, continuation_start) in enumerate(batch):
            continuation_start = max(1, continuation_start)
            if continuation_start >= len(ids):
                results.append((0.0, True))
                continue
            # Token k is predicted at position k - 1; stop before the padding.
            score_start = continuation_start - 1
            score_end = len(ids) - 1
            row_log_probs = log_probs[row_idx, score_start:score_end]
            row_targets = target_ids[row_idx, score_start:score_end]
            token_log_probs = row_log_probs.gather(1, row_targets.unsqueeze(1)).squeeze(1)
            greedy = row_log_probs.argmax(dim=-1)
            is_greedy = bool(torch.equal(greedy, row_targets))
            results.append((float(token_log_probs.sum().item()), is_greedy))
        return results

    def _score_tokens(self, full_ids: list[int], continuation_start: int) -> tuple[float, bool]:
        """Score a single sequence.

        Delegates to ``_score_batch`` so that degenerate inputs (for example, a
        single token with no BOS) are handled by the same guards as batched scoring.
        """
        return self._score_batch([self._prepare_ids(full_ids, continuation_start)])[0]

    def loglikelihood(self, requests: list[Any]) -> list[tuple[float, bool]]:
        """Return ``(log P(continuation | context), is_greedy)`` for each request.

        ``req.args`` is ``(context, continuation)``. The continuation boundary is
        placed ``len(tok_encode(continuation))`` tokens before the end of the joint
        encoding of ``context + continuation``. NOTE(review): this assumes the joint
        encoding ends with the same number of continuation tokens as the separate
        encoding. If the joint encoding is shorter than the continuation alone,
        ``context`` and ``continuation`` are encoded separately and concatenated.
        """
        prepared = []
        for req in tqdm(requests, desc="loglikelihood"):
            context, continuation = req.args
            continuation_ids = self.tok_encode(continuation)
            full_ids = self.tok_encode(context + continuation)
            if len(full_ids) < len(continuation_ids):
                full_ids = self.tok_encode(context) + continuation_ids
            continuation_start = max(0, len(full_ids) - len(continuation_ids))
            prepared.append(self._prepare_ids(full_ids, continuation_start))
        results = []
        for start in tqdm(range(0, len(prepared), self.eval_batch_size), desc="score_batches"):
            results.extend(self._score_batch(prepared[start : start + self.eval_batch_size]))
        return results

    def loglikelihood_rolling(self, requests: list[Any]) -> list[float]:
        """Return the total log-probability of each request's text.

        ``req.args`` is ``(text,)``. The text's tokens are split into consecutive
        windows of ``max_length - 1`` tokens. Each window is scored with BOS as its
        only context when the tokenizer has a BOS id. Otherwise the immediately
        preceding token is used as context, so only the very first token of the text
        goes unscored. No other context carries across windows. All windows of all
        requests are scored in batches of ``eval_batch_size``.
        """
        bos = self._bos_id()
        step = max(1, self.max_length - 1)
        prepared: list[tuple[list[int], int]] = []
        owners: list[int] = []  # request index for each prepared window
        for req_idx, req in enumerate(tqdm(requests, desc="rolling")):
            (text,) = req.args
            ids = self.tok_encode(text)
            for start in range(0, len(ids), step):
                window = ids[start : start + step]
                if bos is not None:
                    context = [bos]
                else:
                    context = ids[start - 1 : start] if start > 0 else []
                prepared.append(self._prepare_ids(context + window, len(context)))
                owners.append(req_idx)
        totals = [0.0] * len(requests)
        for start in tqdm(range(0, len(prepared), self.eval_batch_size), desc="score_rolling"):
            stop = start + self.eval_batch_size
            for owner, (score, _) in zip(owners[start:stop], self._score_batch(prepared[start:stop])):
                totals[owner] += score
        return totals

    def generate_until(self, requests: list[Any]) -> list[str]:
        """Generate text for each request and cut it at the first stop sequence.

        ``req.args`` is ``(context, gen_kwargs)``. ``gen_kwargs["max_gen_toks"]``
        (default 64) caps the new tokens. ``gen_kwargs["until"]`` may be a string,
        a list of strings, or absent/``None``. The returned text is truncated at
        each stop string in turn.
        """
        # Imported on first use only.
        from chat_ssm_fixed import generate

        outputs = []
        for req in tqdm(requests, desc="generate_until"):
            context, gen_kwargs = req.args
            max_gen_toks = int(gen_kwargs.get("max_gen_toks", 64))
            until = gen_kwargs.get("until") or []
            if isinstance(until, str):
                until = [until]
            # NOTE(review): greedy=True is passed together with sampling settings and
            # repetition_penalty=1.2; which of these apply is decided by
            # chat_ssm_fixed.generate.
            text = generate(
                self.model,
                self.tokenizer,
                context,
                device=self._device,
                dtype=self.dtype,
                max_new_tokens=max_gen_toks,
                temperature=0.7,
                top_p=0.85,
                repetition_penalty=1.2,
                greedy=True,
            )
            for stop in until:
                if stop and stop in text:
                    text = text.split(stop)[0]
            outputs.append(text)
        return outputs
def metric_value(task_result: dict[str, Any], preferred: list[str]) -> float | None:
    """Return the first metric from ``preferred`` that ``task_result`` holds as a number.

    Keys are tried in order. A key that is missing, or that maps to a non-numeric
    value (such as a string), is skipped. The match is returned as a ``float``;
    ``None`` means no key matched.
    """
    numeric_hits = (
        float(task_result[name])
        for name in preferred
        if name in task_result and isinstance(task_result[name], (int, float))
    )
    return next(numeric_hits, None)
def json_safe(obj: Any) -> Any:
    """Recursively convert ``obj`` into a structure ``json.dumps`` accepts.

    Dict keys are stringified, lists and tuples become lists, and JSON scalars
    (``str``, ``int``, ``float``, ``bool``, ``None``) pass through unchanged. Any
    other object is replaced by its ``str()``.
    """
    if isinstance(obj, dict):
        return {str(key): json_safe(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(json_safe, obj))
    is_scalar = obj is None or isinstance(obj, (str, int, float, bool))
    return obj if is_scalar else str(obj)
def main() -> None:
    """Command-line entry point: evaluate the checkpoint with lm-eval and summarize.

    Builds a ``TaoTrainLM`` from the CLI arguments and runs
    ``evaluator.simple_evaluate`` on the comma-separated ``--tasks``. The full
    result dict is written to ``--output`` as JSON (via ``json_safe``). One primary
    score per task is then printed (``n/a`` when none of its preferred metrics is
    present), followed by the mean of the available scores.

    Raises:
        RuntimeError: If lm-eval returns no results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default=str(ROOT / "model" / "pretrain_final_model.pt"))
    parser.add_argument("--tokenizer", default=str(ROOT / "tokenizer" / "tokenizer.model"))
    parser.add_argument(
        "--tasks",
        default="mmlu,hellaswag,arc_easy,arc_challenge,piqa,winogrande",
    )
    parser.add_argument("--num-fewshot", type=int, default=0)
    parser.add_argument("--limit", type=float, default=None)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
    parser.add_argument("--ssm-kernel-mode", choices=["conv", "recurrent"], default="conv")
    parser.add_argument("--finite-tail", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--eval-batch-size", type=int, default=8)
    parser.add_argument("--output", default=str(ROOT / "artifacts" / "lm_eval_results.json"))
    args = parser.parse_args()
    model = TaoTrainLM(
        checkpoint=args.checkpoint,
        tokenizer=args.tokenizer,
        device=args.device,
        dtype=args.dtype,
        ssm_kernel_mode=args.ssm_kernel_mode,
        finite_tail=args.finite_tail,
        eval_batch_size=args.eval_batch_size,
    )
    # lm-eval's LM base class does not define `device`; fall back to the adapter's
    # own `_device` attribute so this diagnostic print cannot raise.
    print(f"device={getattr(model, 'device', model._device)}")
    print(f"ssm_overrides={model.override_count}")
    print(f"ssm_kernel_mode={args.ssm_kernel_mode}")
    print(f"finite_tail={args.finite_tail}")
    task_names = [item.strip() for item in args.tasks.split(",") if item.strip()]
    # TaoTrainLM batches internally using --eval-batch-size.
    results = evaluator.simple_evaluate(
        model=model,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        limit=args.limit,
        batch_size=1,
        log_samples=False,
        verbosity="INFO",
    )
    if results is None:
        raise RuntimeError("lm-eval returned no results")
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(json_safe(results), indent=2, ensure_ascii=False), encoding="utf-8")
    # Primary metric per task, in order of preference.
    preferred = {
        "mmlu": ["acc,none", "acc"],
        "hellaswag": ["acc_norm,none", "acc_norm", "acc,none", "acc"],
        "arc_easy": ["acc_norm,none", "acc_norm", "acc,none", "acc"],
        "arc_challenge": ["acc_norm,none", "acc_norm", "acc,none", "acc"],
        "piqa": ["acc_norm,none", "acc_norm", "acc,none", "acc"],
        "winogrande": ["acc,none", "acc"],
    }
    print("\nTask benchmark:")
    values = []
    for task in task_names:
        task_result = results["results"].get(task, {})
        value = metric_value(task_result, preferred.get(task, ["acc_norm,none", "acc,none", "acc_norm", "acc"]))
        if value is None:
            # Report the gap instead of silently omitting the task from the summary.
            print(f" {task}: n/a")
            continue
        values.append(value)
        print(f" {task}: {value:.4f}")
    if values:
        print(f" mean_primary_score: {sum(values) / len(values):.4f}")
    print(f" num_fewshot: {args.num_fewshot}")
    if args.limit is not None:
        print(f" limit: {args.limit}")


if __name__ == "__main__":
    main()