Atom2.7m / lm_eval_fusion.py
ucr-max's picture
Upload Atom2.7m model
271e253 verified
Raw
History Blame Contribute Delete
12.3 kB
"""lm-eval wrapper for Atom2.7m checkpoints.
The standard ``hf`` lm-eval model neither uses the fusion tokenizer wrapper nor
passes the arithmetic feature streams (``place_ids`` and ``role_ids``) to the
model. This adapter keeps lm-eval's log-likelihood interface while encoding
text with ``tokenizer_utils.load_tokenizer`` and forwarding both streams.
"""
from __future__ import annotations
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import torch
import torch.nn.functional as F
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from tokenizer_utils import EOT_ID, FusionTokenizer, load_tokenizer
def _parse_bool(value: Any, default: bool = False) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _parse_batch_size(value: int | str | None, max_batch_size: int | None) -> int:
if value is None:
return 1
if isinstance(value, int):
return value
text = str(value).strip().lower()
if text == "auto" or text.startswith("auto:"):
return int(max_batch_size or 64)
return int(text)
def _dtype_from_name(value: str | torch.dtype | None) -> torch.dtype | None:
if value is None or value == "auto":
return None
if isinstance(value, torch.dtype):
return value
normalized = str(value).replace("torch.", "").lower()
if normalized in {"bf16", "bfloat16"}:
return torch.bfloat16
if normalized in {"fp16", "float16", "half"}:
return torch.float16
if normalized in {"fp32", "float32", "float"}:
return torch.float32
raise ValueError(f"Unsupported dtype: {value!r}")
@register_model("atom2.7m")
class FusionGPTLM(LM):
    """Fusion-tokenizer GPT adapter for lm-eval log-likelihood tasks.

    Text is encoded with the project's ``FusionTokenizer``. Besides
    ``input_ids``, every forward pass also receives the tokenizer's
    ``place_ids`` and ``role_ids`` streams. ``loglikelihood`` and
    ``loglikelihood_rolling`` are implemented; ``generate_until`` raises
    ``NotImplementedError``.
    """

    def __init__(
        self,
        pretrained: str = "outputs/fusion_run/final_model",
        tokenizer_dir: str = "tokenizer_4k",
        batch_size: int | str | None = 1,
        max_batch_size: int | None = 64,
        max_length: int | None = None,
        device: str | None = "cuda",
        dtype: str | torch.dtype | None = "auto",
        mixed_precision_dtype: str | torch.dtype | None = "auto",
        trust_remote_code: bool | str | None = None,
        **_: Any,
    ) -> None:
        """Load the tokenizer and checkpoint and resolve evaluation settings.

        Args:
            pretrained: Directory passed to ``AutoModelForCausalLM.from_pretrained``.
            tokenizer_dir: Directory passed to ``load_tokenizer``.
            batch_size: Requests per forward pass in ``loglikelihood``. Accepts
                an int, a numeric string, or ``"auto"`` / ``"auto:N"``.
            max_batch_size: Batch size used when ``batch_size`` is ``"auto"``.
            max_length: Maximum number of input tokens per forward pass.
                Falls back to the model config's ``block_size``, then
                ``max_position_embeddings``, then 512.
            device: Torch device string. ``None`` / ``"auto"`` picks CUDA when
                available, otherwise CPU.
            dtype: Optional dtype to cast the model weights to. ``"auto"``
                keeps the loaded dtype.
            mixed_precision_dtype: Autocast dtype for forward passes on CUDA.
                ``"auto"`` means bfloat16 on CUDA and no autocast elsewhere.
            trust_remote_code: Accepted for CLI compatibility and ignored.

        Raises:
            ValueError: for an invalid batch size, dtype, or a resolved
                ``max_length`` smaller than 1.
        """
        super().__init__()
        # from_pretrained below always passes trust_remote_code=True, so the
        # CLI flag is accepted only so lm-eval can forward it.
        del trust_remote_code
        if device is None or device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device)
        self.batch_size = _parse_batch_size(batch_size, max_batch_size)
        self.tokenizer: FusionTokenizer = load_tokenizer(Path(tokenizer_dir))
        self.model = AutoModelForCausalLM.from_pretrained(
            Path(pretrained),
            trust_remote_code=True,
        ).to(self.device)
        model_dtype = _dtype_from_name(dtype)
        if model_dtype is not None:
            self.model = self.model.to(dtype=model_dtype)
        if mixed_precision_dtype == "auto":
            # Default: bfloat16 autocast on CUDA, no autocast on other devices.
            self.mixed_precision_dtype = (
                torch.bfloat16 if self.device.type == "cuda" else None
            )
        else:
            self.mixed_precision_dtype = _dtype_from_name(mixed_precision_dtype)
        self.model.eval()
        self.max_length = int(
            max_length
            or getattr(self.model.config, "block_size", None)
            or getattr(self.model.config, "max_position_embeddings", 512)
        )
        if self.max_length < 1:
            # A non-positive window would make the rolling loop never advance
            # and leave no room for continuation tokens.
            raise ValueError(f"max_length must be >= 1, got {self.max_length}")

    @property
    def device(self) -> torch.device:
        """Torch device that holds the model and every input tensor.

        ``__init__`` stores the device as ``self._device``, and the methods of
        this class read it through this property.
        """
        return self._device

    @property
    def eot_token_id(self) -> int:
        """Token id used as the prefix for empty contexts and as padding."""
        return EOT_ID

    def tok_encode(
        self,
        string: str,
        add_special_tokens: bool | None = None,
        left_truncate_len: int | None = None,
        **_: Any,
    ) -> list[int]:
        """Encode *string* to token ids, optionally keeping only the last tokens.

        ``add_special_tokens`` is accepted for interface compatibility and
        ignored. When ``left_truncate_len`` is given, at most that many trailing
        ids are kept.
        """
        del add_special_tokens
        ids = self.tokenizer.encode(string).input_ids
        if left_truncate_len is not None:
            # ``ids[-0:]`` would keep the whole list, so slice from an explicit
            # start index; a limit of 0 then correctly yields no tokens.
            ids = ids[max(0, len(ids) - left_truncate_len) :]
        return ids

    def tok_decode(self, tokens, skip_special_tokens: bool = True) -> str:
        """Decode a token id (or list of ids) back to text."""
        if isinstance(tokens, int):
            tokens = [tokens]
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

    def _encode_request(
        self,
        context: str,
        continuation: str,
    ) -> tuple[list[int], list[int], list[int], list[int], int]:
        """Encode one (context, continuation) pair for scoring.

        Returns ``(ids, place_ids, role_ids, continuation_ids, context_len)``,
        where ``ids[:context_len]`` is the conditioning prefix and
        ``continuation_ids == ids[context_len:]`` are the tokens to score.

        Raises:
            ValueError: if the continuation encodes to zero tokens.
        """
        # Move trailing whitespace from the context onto the continuation, so
        # the tokenizer sees it attached to the scored text.
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        context_ids = self.tokenizer.encode(context).input_ids if context else []
        if not context_ids:
            # Nothing to condition on (including whitespace-only contexts):
            # prefix EOT so the first continuation token is still predicted
            # and scored.
            continuation_encoding = self.tokenizer.encode(continuation)
            ids = [self.eot_token_id] + continuation_encoding.input_ids
            place_ids = [0] + continuation_encoding.place_ids
            role_ids = [0] + continuation_encoding.role_ids
            context_len = 1
        else:
            # Encode the pair jointly so the continuation tokens are exactly
            # those the model sees after this context.
            full_encoding = self.tokenizer.encode(context + continuation)
            ids = full_encoding.input_ids
            place_ids = full_encoding.place_ids
            role_ids = full_encoding.role_ids
            context_len = len(context_ids)
        continuation_ids = ids[context_len:]
        if not continuation_ids:
            raise ValueError("Continuation encoded to zero tokens")
        return ids, place_ids, role_ids, continuation_ids, context_len

    def _window_request(
        self,
        ids: list[int],
        place_ids: list[int],
        role_ids: list[int],
        context_len: int,
    ) -> tuple[list[int], list[int], list[int], list[int], int, int]:
        """Left-truncate one encoded request to fit ``max_length`` model inputs.

        Returns ``(input_ids, places, roles, targets, score_start, score_end)``.
        ``targets[i]`` is the token predicted at input position ``i``, and
        ``[score_start, score_end)`` are the positions whose targets are the
        continuation tokens still inside the window.

        Raises:
            ValueError: if truncation leaves no continuation token to score.
        """
        # Keep at most max_length + 1 tokens: max_length inputs, each
        # predicting the token that follows it.
        window_start = max(0, len(ids) - (self.max_length + 1))
        window_ids = ids[window_start:]
        # Input position i predicts ids[window_start + i + 1], so continuation
        # token ids[j] (j >= context_len) is scored at j - 1 - window_start.
        score_start = max(context_len - 1, window_start) - window_start
        score_end = len(ids) - 1 - window_start
        if score_end <= score_start:
            raise ValueError("No continuation tokens remain after truncation")
        return (
            window_ids[:-1],
            place_ids[window_start:-1],
            role_ids[window_start:-1],
            window_ids[1:],
            score_start,
            score_end,
        )

    def _autocast_context(self):
        """Return the mixed-precision context shared by all forward passes.

        Autocast is used only on CUDA with a configured dtype; otherwise this
        is a no-op context.
        """
        if self.device.type != "cuda" or self.mixed_precision_dtype is None:
            return nullcontext()
        return torch.autocast(
            device_type=self.device.type,
            dtype=self.mixed_precision_dtype,
        )

    def _log_probs(
        self,
        input_ids: torch.Tensor,
        place_ids: torch.Tensor,
        role_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the model without gradients and return float32 log-softmax over the vocab."""
        model_kwargs = {
            "input_ids": input_ids,
            "place_ids": place_ids,
            "role_ids": role_ids,
        }
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask
        with torch.inference_mode():
            with self._autocast_context():
                logits = self.model(**model_kwargs).logits
            return F.log_softmax(logits.float(), dim=-1)

    def loglikelihood(
        self,
        requests: list["Instance"],
        disable_tqdm: bool = False,
    ) -> list[tuple[float, bool]]:
        """Score each ``(context, continuation)`` request.

        Returns, in request order, ``(sum of continuation log-probs,
        whether every continuation token is the argmax prediction)``.
        Requests are batched ``self.batch_size`` at a time and right-padded
        with EOT under a boolean attention mask.
        """
        progress_disabled = disable_tqdm or self.rank != 0
        encoded = [
            self._encode_request(context, continuation)
            for context, continuation in tqdm(
                [req.args for req in requests],
                desc="Fusion tokenizing inputs",
                disable=progress_disabled,
            )
        ]
        results: list[tuple[float, bool]] = []
        for start in tqdm(
            range(0, len(encoded), self.batch_size),
            desc="Running fusion loglikelihood requests",
            disable=progress_disabled,
        ):
            batch = [
                self._window_request(ids, place_ids, role_ids, context_len)
                for ids, place_ids, role_ids, _, context_len in encoded[
                    start : start + self.batch_size
                ]
            ]
            max_len = max(len(entry[0]) for entry in batch)
            input_tensor = torch.full(
                (len(batch), max_len),
                self.eot_token_id,
                dtype=torch.long,
                device=self.device,
            )
            place_tensor = torch.zeros_like(input_tensor)
            role_tensor = torch.zeros_like(input_tensor)
            attention_mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            target_tensor = torch.full_like(input_tensor, self.eot_token_id)
            for row, (ids, places, roles, targets, _, _) in enumerate(batch):
                length = len(ids)
                input_tensor[row, :length] = torch.tensor(ids, device=self.device)
                place_tensor[row, :length] = torch.tensor(places, device=self.device)
                role_tensor[row, :length] = torch.tensor(roles, device=self.device)
                target_tensor[row, :length] = torch.tensor(targets, device=self.device)
                attention_mask[row, :length] = True
            log_probs = self._log_probs(
                input_tensor, place_tensor, role_tensor, attention_mask
            )
            for row, (*_, score_start, score_end) in enumerate(batch):
                row_log_probs = log_probs[row, score_start:score_end]
                row_targets = target_tensor[row, score_start:score_end]
                token_log_probs = torch.gather(
                    row_log_probs,
                    1,
                    row_targets.unsqueeze(-1),
                ).squeeze(-1)
                # The scored targets are exactly the continuation tokens.
                greedy = torch.equal(row_log_probs.argmax(dim=-1), row_targets)
                results.append((float(token_log_probs.sum().item()), bool(greedy)))
        return results

    def _rolling_loglikelihood(self, text: str) -> float:
        """Return the summed log-probability of every token of *text*.

        The stream is prefixed with EOT so the first token is scored. Tokens
        are scored in chunks of at most ``max_length``, and each token is
        scored exactly once. Each chunk's input window reaches back as far as
        ``max_length`` allows, so a final partial chunk is conditioned on up to
        ``max_length`` preceding tokens rather than just one.
        """
        encoding = self.tokenizer.encode(text)
        n_tokens = len(encoding.input_ids)
        seq_ids = [self.eot_token_id] + encoding.input_ids
        seq_places = [0] + encoding.place_ids
        seq_roles = [0] + encoding.role_ids
        total = 0.0
        start = 0
        while start < n_tokens:
            end = min(n_tokens, start + self.max_length)
            # seq position p predicts seq[p + 1] (== token p of the text), so
            # this chunk is scored from positions start..end-1. The window
            # reuses earlier positions as context when there is room.
            window_start = max(0, end - self.max_length)
            log_probs = self._log_probs(
                torch.tensor([seq_ids[window_start:end]], dtype=torch.long, device=self.device),
                torch.tensor([seq_places[window_start:end]], dtype=torch.long, device=self.device),
                torch.tensor([seq_roles[window_start:end]], dtype=torch.long, device=self.device),
            )[0]
            targets = torch.tensor(
                seq_ids[start + 1 : end + 1], dtype=torch.long, device=self.device
            )
            scored = log_probs[start - window_start :]
            total += float(torch.gather(scored, 1, targets.unsqueeze(-1)).sum().item())
            start = end
        return total

    def loglikelihood_rolling(
        self,
        requests: list["Instance"],
        disable_tqdm: bool = False,
    ) -> list[float]:
        """Return the full-text log-likelihood of each single-string request."""
        return [
            self._rolling_loglikelihood(text)
            for (text,) in tqdm(
                [req.args for req in requests],
                desc="Running fusion rolling loglikelihood",
                disable=disable_tqdm or self.rank != 0,
            )
        ]

    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
        """Not supported: this adapter only implements log-likelihood scoring."""
        raise NotImplementedError(
            "FusionGPTLM currently supports loglikelihood tasks. "
            "Use tasks with multiple-choice/loglikelihood output."
        )