"""lm-eval wrapper for Atom2.7m checkpoints. The standard ``hf`` lm-eval model does not use the fusion tokenizer wrapper and does not pass arithmetic feature streams. This model keeps lm-eval's log-likelihood interface while encoding with ``tokenizer_utils.load_tokenizer`` and forwarding ``place_ids`` and ``role_ids``. """ from __future__ import annotations from contextlib import nullcontext from pathlib import Path from typing import Any import torch import torch.nn.functional as F from lm_eval.api.model import LM from lm_eval.api.registry import register_model from tqdm import tqdm from transformers import AutoModelForCausalLM from tokenizer_utils import EOT_ID, FusionTokenizer, load_tokenizer def _parse_bool(value: Any, default: bool = False) -> bool: if value is None: return default if isinstance(value, bool): return value return str(value).strip().lower() in {"1", "true", "yes", "on"} def _parse_batch_size(value: int | str | None, max_batch_size: int | None) -> int: if value is None: return 1 if isinstance(value, int): return value text = str(value).strip().lower() if text == "auto" or text.startswith("auto:"): return int(max_batch_size or 64) return int(text) def _dtype_from_name(value: str | torch.dtype | None) -> torch.dtype | None: if value is None or value == "auto": return None if isinstance(value, torch.dtype): return value normalized = str(value).replace("torch.", "").lower() if normalized in {"bf16", "bfloat16"}: return torch.bfloat16 if normalized in {"fp16", "float16", "half"}: return torch.float16 if normalized in {"fp32", "float32", "float"}: return torch.float32 raise ValueError(f"Unsupported dtype: {value!r}") @register_model("atom2.7m") class FusionGPTLM(LM): """Fusion-tokenizer GPT adapter for lm-eval log-likelihood tasks.""" def __init__( self, pretrained: str = "outputs/fusion_run/final_model", tokenizer_dir: str = "tokenizer_4k", batch_size: int | str | None = 1, max_batch_size: int | None = 64, max_length: int | None = None, device: str | None = "cuda", dtype: str | torch.dtype | None = "auto", mixed_precision_dtype: str | torch.dtype | None = "auto", trust_remote_code: bool | str | None = None, **_: Any, ) -> None: super().__init__() del trust_remote_code if device is None or device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" self._device = torch.device(device) self.batch_size = _parse_batch_size(batch_size, max_batch_size) self.tokenizer: FusionTokenizer = load_tokenizer(Path(tokenizer_dir)) self.model = AutoModelForCausalLM.from_pretrained( Path(pretrained), trust_remote_code=True, ).to(self.device) model_dtype = _dtype_from_name(dtype) if model_dtype is not None: self.model = self.model.to(dtype=model_dtype) if mixed_precision_dtype == "auto": self.mixed_precision_dtype = ( torch.bfloat16 if self.device.type == "cuda" else None ) else: self.mixed_precision_dtype = _dtype_from_name(mixed_precision_dtype) self.model.eval() self.max_length = int( max_length or getattr(self.model.config, "block_size", None) or getattr(self.model.config, "max_position_embeddings", 512) ) @property def eot_token_id(self) -> int: return EOT_ID def tok_encode( self, string: str, add_special_tokens: bool | None = None, left_truncate_len: int | None = None, **_: Any, ) -> list[int]: del add_special_tokens ids = self.tokenizer.encode(string).input_ids if left_truncate_len is not None: ids = ids[-left_truncate_len:] return ids def tok_decode(self, tokens, skip_special_tokens: bool = True) -> str: if isinstance(tokens, int): tokens = [tokens] return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) def _encode_request( self, context: str, continuation: str, ) -> tuple[list[int], list[int], list[int], list[int], int]: if context == "": continuation_encoding = self.tokenizer.encode(continuation) ids = [self.eot_token_id] + continuation_encoding.input_ids place_ids = [0] + continuation_encoding.place_ids role_ids = [0] + continuation_encoding.role_ids context_len = 1 continuation_ids = continuation_encoding.input_ids else: n_spaces = len(context) - len(context.rstrip()) if n_spaces > 0: continuation = context[-n_spaces:] + continuation context = context[:-n_spaces] full_encoding = self.tokenizer.encode(context + continuation) context_encoding = self.tokenizer.encode(context) ids = full_encoding.input_ids place_ids = full_encoding.place_ids role_ids = full_encoding.role_ids context_len = len(context_encoding.input_ids) continuation_ids = ids[context_len:] if not continuation_ids: raise ValueError("Continuation encoded to zero tokens") return ids, place_ids, role_ids, continuation_ids, context_len def loglikelihood( self, requests: list["Instance"], disable_tqdm: bool = False, ) -> list[tuple[float, bool]]: encoded = [ self._encode_request(context, continuation) for context, continuation in tqdm( [req.args for req in requests], desc="Fusion tokenizing inputs", disable=disable_tqdm, ) ] results: list[tuple[float, bool]] = [] for start in tqdm( range(0, len(encoded), self.batch_size), desc="Running fusion loglikelihood requests", disable=disable_tqdm or self.rank != 0, ): batch = encoded[start : start + self.batch_size] rows = [] row_places = [] row_roles = [] row_targets = [] row_score_slices = [] for ids, place_ids, role_ids, continuation_ids, context_len in batch: window_start = max(0, len(ids) - (self.max_length + 1)) window_ids = ids[window_start:] window_places = place_ids[window_start:] window_roles = role_ids[window_start:] input_ids = window_ids[:-1] targets = window_ids[1:] full_score_start = context_len - 1 full_score_end = len(ids) - 1 score_start = max(full_score_start, window_start) - window_start score_end = full_score_end - window_start if score_end <= score_start: raise ValueError("No continuation tokens remain after truncation") scored_continuation_ids = continuation_ids[-(score_end - score_start) :] rows.append(input_ids) row_places.append(window_places[:-1]) row_roles.append(window_roles[:-1]) row_targets.append(targets) row_score_slices.append((score_start, score_end, scored_continuation_ids)) max_len = max(len(row) for row in rows) input_tensor = torch.full( (len(rows), max_len), self.eot_token_id, dtype=torch.long, device=self.device, ) place_tensor = torch.zeros_like(input_tensor) role_tensor = torch.zeros_like(input_tensor) attention_mask = torch.zeros_like(input_tensor, dtype=torch.bool) target_tensor = torch.full_like(input_tensor, self.eot_token_id) for row, (ids, places, roles, targets) in enumerate( zip(rows, row_places, row_roles, row_targets, strict=True) ): length = len(ids) input_tensor[row, :length] = torch.tensor(ids, device=self.device) place_tensor[row, :length] = torch.tensor(places, device=self.device) role_tensor[row, :length] = torch.tensor(roles, device=self.device) target_tensor[row, :length] = torch.tensor(targets, device=self.device) attention_mask[row, :length] = True autocast = ( torch.autocast( device_type=self.device.type, dtype=self.mixed_precision_dtype, enabled=self.mixed_precision_dtype is not None, ) if self.device.type == "cuda" else nullcontext() ) with torch.inference_mode(), autocast: logits = self.model( input_ids=input_tensor, place_ids=place_tensor, role_ids=role_tensor, attention_mask=attention_mask, ).logits log_probs = F.log_softmax(logits.float(), dim=-1) for row, (score_start, score_end, continuation_ids) in enumerate(row_score_slices): row_log_probs = log_probs[row, score_start:score_end] row_targets_for_score = target_tensor[row, score_start:score_end] token_log_probs = torch.gather( row_log_probs, 1, row_targets_for_score.unsqueeze(-1), ).squeeze(-1) greedy = torch.equal( row_log_probs.argmax(dim=-1), torch.tensor(continuation_ids, dtype=torch.long, device=self.device), ) results.append((float(token_log_probs.sum().item()), bool(greedy))) return results def loglikelihood_rolling( self, requests: list["Instance"], disable_tqdm: bool = False, ) -> list[float]: results = [] for (text,) in tqdm( [req.args for req in requests], desc="Running fusion rolling loglikelihood", disable=disable_tqdm or self.rank != 0, ): encoding = self.tokenizer.encode(text) ids = encoding.input_ids places = encoding.place_ids roles = encoding.role_ids total = 0.0 start = 0 while start < len(ids): end = min(len(ids), start + self.max_length) prefix = [self.eot_token_id] if start == 0 else ids[start - 1 : start] chunk_ids = prefix + ids[start:end] chunk_places = [0] + places[start:end] if start == 0 else places[start - 1 : end] chunk_roles = [0] + roles[start:end] if start == 0 else roles[start - 1 : end] input_ids = torch.tensor([chunk_ids[:-1]], dtype=torch.long, device=self.device) place_ids = torch.tensor([chunk_places[:-1]], dtype=torch.long, device=self.device) role_ids = torch.tensor([chunk_roles[:-1]], dtype=torch.long, device=self.device) targets = torch.tensor(chunk_ids[1:], dtype=torch.long, device=self.device) with torch.inference_mode(): logits = self.model( input_ids=input_ids, place_ids=place_ids, role_ids=role_ids, ).logits[0] log_probs = F.log_softmax(logits.float(), dim=-1) total += float( torch.gather(log_probs, 1, targets.unsqueeze(-1)).sum().item() ) start = end results.append(total) return results def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: raise NotImplementedError( "FusionGPTLM currently supports loglikelihood tasks. " "Use tasks with multiple-choice/loglikelihood output." )