| """lm-eval wrapper for Atom2.7m checkpoints. |
| |
| The standard ``hf`` lm-eval model does not use the fusion tokenizer wrapper and |
| does not pass arithmetic feature streams. This model keeps lm-eval's |
| log-likelihood interface while encoding with ``tokenizer_utils.load_tokenizer`` |
| and forwarding ``place_ids`` and ``role_ids``. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from contextlib import nullcontext |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| import torch.nn.functional as F |
| from lm_eval.api.model import LM |
| from lm_eval.api.registry import register_model |
| from tqdm import tqdm |
| from transformers import AutoModelForCausalLM |
|
|
| from tokenizer_utils import EOT_ID, FusionTokenizer, load_tokenizer |
|
|
|
|
| def _parse_bool(value: Any, default: bool = False) -> bool: |
| if value is None: |
| return default |
| if isinstance(value, bool): |
| return value |
| return str(value).strip().lower() in {"1", "true", "yes", "on"} |
|
|
|
|
| def _parse_batch_size(value: int | str | None, max_batch_size: int | None) -> int: |
| if value is None: |
| return 1 |
| if isinstance(value, int): |
| return value |
| text = str(value).strip().lower() |
| if text == "auto" or text.startswith("auto:"): |
| return int(max_batch_size or 64) |
| return int(text) |
|
|
|
|
| def _dtype_from_name(value: str | torch.dtype | None) -> torch.dtype | None: |
| if value is None or value == "auto": |
| return None |
| if isinstance(value, torch.dtype): |
| return value |
| normalized = str(value).replace("torch.", "").lower() |
| if normalized in {"bf16", "bfloat16"}: |
| return torch.bfloat16 |
| if normalized in {"fp16", "float16", "half"}: |
| return torch.float16 |
| if normalized in {"fp32", "float32", "float"}: |
| return torch.float32 |
| raise ValueError(f"Unsupported dtype: {value!r}") |
|
|
|
|
@register_model("atom2.7m")
class FusionGPTLM(LM):
    """Fusion-tokenizer GPT adapter for lm-eval log-likelihood tasks.

    Text is encoded with the tokenizer returned by ``load_tokenizer`` and the
    resulting ``input_ids``, ``place_ids`` and ``role_ids`` streams are all
    forwarded to the model. ``loglikelihood`` and ``loglikelihood_rolling``
    are implemented; ``generate_until`` raises ``NotImplementedError``.
    """

    def __init__(
        self,
        pretrained: str = "outputs/fusion_run/final_model",
        tokenizer_dir: str = "tokenizer_4k",
        batch_size: int | str | None = 1,
        max_batch_size: int | None = 64,
        max_length: int | None = None,
        device: str | None = "cuda",
        dtype: str | torch.dtype | None = "auto",
        mixed_precision_dtype: str | torch.dtype | None = "auto",
        trust_remote_code: bool | str | None = None,
        **_: Any,
    ) -> None:
        """Load the tokenizer and checkpoint and resolve runtime settings.

        Args:
            pretrained: Checkpoint location given to
                ``AutoModelForCausalLM.from_pretrained``.
            tokenizer_dir: Directory given to ``load_tokenizer``.
            batch_size: Requests per forward pass in ``loglikelihood``;
                resolved by ``_parse_batch_size``.
            max_batch_size: Passed to ``_parse_batch_size`` alongside
                ``batch_size``.
            max_length: Maximum number of input positions per forward pass.
                When falsy, falls back to ``config.block_size``, then
                ``config.max_position_embeddings``, then 512.
            device: Torch device string. ``None`` or ``"auto"`` selects
                ``"cuda"`` when available and ``"cpu"`` otherwise.
            dtype: Parameter dtype to cast the model to; values that
                ``_dtype_from_name`` maps to ``None`` leave the model as
                loaded.
            mixed_precision_dtype: Autocast dtype for CUDA forward passes.
                ``"auto"`` selects bfloat16 on CUDA and no autocast elsewhere.
            trust_remote_code: Accepted and ignored (see note in the body).
            **_: Any further keyword arguments are accepted and ignored.
        """
        super().__init__()
        # NOTE(review): the caller's ``trust_remote_code`` is intentionally
        # discarded and the checkpoint is always loaded with
        # ``trust_remote_code=True``, which lets code shipped with the
        # checkpoint execute. Only point ``pretrained`` at trusted locations.
        del trust_remote_code
        if device is None or device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device)
        self.batch_size = _parse_batch_size(batch_size, max_batch_size)
        self.tokenizer: FusionTokenizer = load_tokenizer(Path(tokenizer_dir))
        self.model = AutoModelForCausalLM.from_pretrained(
            Path(pretrained),
            trust_remote_code=True,
        ).to(self._device)
        model_dtype = _dtype_from_name(dtype)
        if model_dtype is not None:
            self.model = self.model.to(dtype=model_dtype)
        wants_auto = (
            isinstance(mixed_precision_dtype, str)
            and mixed_precision_dtype.strip().lower() == "auto"
        )
        if wants_auto:
            self.mixed_precision_dtype = (
                torch.bfloat16 if self._device.type == "cuda" else None
            )
        else:
            self.mixed_precision_dtype = _dtype_from_name(mixed_precision_dtype)
        self.model.eval()
        self.max_length = int(
            max_length
            or getattr(self.model.config, "block_size", None)
            or getattr(self.model.config, "max_position_embeddings", 512)
        )

    @property
    def device(self) -> torch.device:
        """Device the model and all input tensors are placed on."""
        return self._device

    @property
    def eot_token_id(self) -> int:
        """Token id used as the leading context token and as padding."""
        return EOT_ID

    def tok_encode(
        self,
        string: str,
        add_special_tokens: bool | None = None,
        left_truncate_len: int | None = None,
        **_: Any,
    ) -> list[int]:
        """Encode ``string`` to token ids, optionally keeping only the tail.

        ``add_special_tokens`` and extra keyword arguments are ignored. When
        ``left_truncate_len`` is given, only the last ``left_truncate_len``
        ids are returned.
        """
        del add_special_tokens
        ids = self.tokenizer.encode(string).input_ids
        if left_truncate_len is not None:
            ids = ids[-left_truncate_len:]
        return ids

    def tok_decode(self, tokens, skip_special_tokens: bool = True) -> str:
        """Decode a single token id or a sequence of ids back to text."""
        if isinstance(tokens, int):
            tokens = [tokens]
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def _encode_request(
        self,
        context: str,
        continuation: str,
    ) -> tuple[list[int], list[int], list[int], list[int], int]:
        """Encode one (context, continuation) pair for scoring.

        Trailing whitespace of ``context`` is moved to the front of
        ``continuation`` before encoding. The context length is measured by
        encoding the context on its own; the continuation is whatever follows
        that many tokens in the joint encoding. If the context contributes no
        tokens (empty, whitespace-only, or encoding to nothing), an EOT token
        with place/role id 0 is prepended so the first continuation token
        still has something to be conditioned on.

        Returns:
            ``(ids, place_ids, role_ids, continuation_ids, context_len)``.

        Raises:
            ValueError: If no continuation tokens remain.
        """
        stripped_context = context.rstrip()
        if len(stripped_context) < len(context):
            continuation = context[len(stripped_context):] + continuation
            context = stripped_context

        full_encoding = self.tokenizer.encode(context + continuation)
        ids = full_encoding.input_ids
        place_ids = full_encoding.place_ids
        role_ids = full_encoding.role_ids
        context_len = len(self.tokenizer.encode(context).input_ids) if context else 0

        if context_len == 0:
            # Without this, the scoring window would start at index -1.
            ids = [self.eot_token_id] + ids
            place_ids = [0] + place_ids
            role_ids = [0] + role_ids
            context_len = 1

        continuation_ids = ids[context_len:]
        if not continuation_ids:
            raise ValueError("Continuation encoded to zero tokens")
        return ids, place_ids, role_ids, continuation_ids, context_len

    def _autocast(self):
        """Return the context manager wrapping a model forward pass.

        Autocast is only used on CUDA, and only enabled when a mixed
        precision dtype is configured; otherwise this is a no-op context.
        """
        if self._device.type != "cuda":
            return nullcontext()
        return torch.autocast(
            device_type=self._device.type,
            dtype=self.mixed_precision_dtype,
            enabled=self.mixed_precision_dtype is not None,
        )

    def _score_batch(
        self,
        batch: list[tuple[list[int], list[int], list[int], list[int], int]],
    ) -> list[tuple[float, bool]]:
        """Score one batch of encoded requests with a single forward pass.

        Each element of ``batch`` is a tuple produced by ``_encode_request``.
        Returns one ``(summed log-probability, is_greedy)`` pair per request,
        in input order.
        """
        eot = self.eot_token_id
        device = self._device

        rows = []
        for ids, place_ids, role_ids, _continuation_ids, context_len in batch:
            # Keep at most ``max_length`` input positions plus the final
            # target token, dropping tokens from the left.
            window_start = max(0, len(ids) - (self.max_length + 1))
            window_ids = ids[window_start:]
            # Input position i predicts token i + 1, so the continuation is
            # predicted by positions [context_len - 1, len(ids) - 1) of the
            # untruncated sequence; shift that span into window coordinates.
            score_start = max(context_len - 1, window_start) - window_start
            score_end = len(ids) - 1 - window_start
            if score_end <= score_start:
                raise ValueError("No continuation tokens remain after truncation")
            rows.append(
                (
                    window_ids[:-1],
                    place_ids[window_start:-1],
                    role_ids[window_start:-1],
                    window_ids[1:],
                    score_start,
                    score_end,
                )
            )

        width = max(len(row[0]) for row in rows)
        # Right-pad: token streams with EOT, feature streams with 0; padded
        # positions are masked out via ``attention_mask``.
        input_tensor = torch.full(
            (len(rows), width),
            eot,
            dtype=torch.long,
            device=device,
        )
        target_tensor = torch.full_like(input_tensor, eot)
        place_tensor = torch.zeros_like(input_tensor)
        role_tensor = torch.zeros_like(input_tensor)
        attention_mask = torch.zeros_like(input_tensor, dtype=torch.bool)
        for index, (inputs, places, roles, targets, _, _) in enumerate(rows):
            length = len(inputs)
            input_tensor[index, :length] = torch.tensor(inputs, device=device)
            place_tensor[index, :length] = torch.tensor(places, device=device)
            role_tensor[index, :length] = torch.tensor(roles, device=device)
            target_tensor[index, :length] = torch.tensor(targets, device=device)
            attention_mask[index, :length] = True

        scored: list[tuple[float, bool]] = []
        with torch.inference_mode():
            with self._autocast():
                logits = self.model(
                    input_ids=input_tensor,
                    place_ids=place_tensor,
                    role_ids=role_tensor,
                    attention_mask=attention_mask,
                ).logits
            log_probs = F.log_softmax(logits.float(), dim=-1)

            for index, (*_, score_start, score_end) in enumerate(rows):
                span_log_probs = log_probs[index, score_start:score_end]
                span_targets = target_tensor[index, score_start:score_end]
                token_log_probs = torch.gather(
                    span_log_probs,
                    1,
                    span_targets.unsqueeze(-1),
                ).squeeze(-1)
                # The target slice is exactly the scored continuation tokens.
                is_greedy = torch.equal(span_log_probs.argmax(dim=-1), span_targets)
                scored.append((float(token_log_probs.sum().item()), bool(is_greedy)))
        return scored

    def loglikelihood(
        self,
        requests: list["Instance"],
        disable_tqdm: bool = False,
    ) -> list[tuple[float, bool]]:
        """Score each request's continuation given its context.

        Args:
            requests: Objects whose ``args`` attribute is a
                ``(context, continuation)`` pair of strings.
            disable_tqdm: Suppress progress bars.

        Returns:
            One ``(log-likelihood, is_greedy)`` tuple per request, in request
            order. ``is_greedy`` is true when every scored continuation token
            is the model's argmax prediction.

        Raises:
            ValueError: If a continuation encodes to zero tokens or is
                entirely removed by truncation.
        """
        hide_progress = disable_tqdm or self.rank != 0
        encoded = [
            self._encode_request(context, continuation)
            for context, continuation in tqdm(
                [req.args for req in requests],
                desc="Fusion tokenizing inputs",
                disable=hide_progress,
            )
        ]
        results: list[tuple[float, bool]] = []
        for start in tqdm(
            range(0, len(encoded), self.batch_size),
            desc="Running fusion loglikelihood requests",
            disable=hide_progress,
        ):
            results.extend(self._score_batch(encoded[start : start + self.batch_size]))
        return results

    def loglikelihood_rolling(
        self,
        requests: list["Instance"],
        disable_tqdm: bool = False,
    ) -> list[float]:
        """Compute the total log-likelihood of each full text.

        Each text is split into consecutive chunks of at most ``max_length``
        tokens. The first chunk is conditioned on a leading EOT token; every
        later chunk is conditioned only on the single token preceding it.
        Forward passes use the same autocast settings as ``loglikelihood``.

        Args:
            requests: Objects whose ``args`` attribute is a one-element tuple
                holding the text.
            disable_tqdm: Suppress the progress bar.

        Returns:
            One summed log-probability per request; ``0.0`` for a text that
            encodes to zero tokens.
        """
        device = self._device
        results = []
        for (text,) in tqdm(
            [req.args for req in requests],
            desc="Running fusion rolling loglikelihood",
            disable=disable_tqdm or self.rank != 0,
        ):
            encoding = self.tokenizer.encode(text)
            token_count = len(encoding.input_ids)
            # Shift every stream right by one so that slice [start:end + 1]
            # is "the token before the chunk" followed by the chunk itself.
            ids = [self.eot_token_id] + encoding.input_ids
            places = [0] + encoding.place_ids
            roles = [0] + encoding.role_ids

            total = 0.0
            for start in range(0, token_count, self.max_length):
                end = min(token_count, start + self.max_length)
                chunk_ids = ids[start : end + 1]
                chunk_places = places[start : end + 1]
                chunk_roles = roles[start : end + 1]
                input_ids = torch.tensor([chunk_ids[:-1]], dtype=torch.long, device=device)
                place_ids = torch.tensor([chunk_places[:-1]], dtype=torch.long, device=device)
                role_ids = torch.tensor([chunk_roles[:-1]], dtype=torch.long, device=device)
                targets = torch.tensor(chunk_ids[1:], dtype=torch.long, device=device)
                with torch.inference_mode():
                    with self._autocast():
                        logits = self.model(
                            input_ids=input_ids,
                            place_ids=place_ids,
                            role_ids=role_ids,
                        ).logits[0]
                    log_probs = F.log_softmax(logits.float(), dim=-1)
                    total += float(
                        torch.gather(log_probs, 1, targets.unsqueeze(-1)).sum().item()
                    )
            results.append(total)
        return results

    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
        """Generation is not supported by this adapter; always raises."""
        raise NotImplementedError(
            "FusionGPTLM currently supports loglikelihood tasks. "
            "Use tasks with multiple-choice/loglikelihood output."
        )
|
|