new-language_model / utils /data_utils.py
Raja-65's picture
Add ELF demo
6ab8280
Raw
History Blame Contribute Delete
7.57 kB
import json
from typing import Dict, Optional
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from utils.encoder_utils import build_self_attn_cond_masks
from utils.logging_utils import log_for_0
def _process_count() -> int:
    """Return the number of processes in the current distributed job.

    Falls back to ``1`` when ``torch.distributed`` is unavailable, has not
    been initialized, or raises while being imported or queried.
    """
    world_size = 1
    try:
        import torch.distributed as dist

        group_ready = dist.is_available() and dist.is_initialized()
        if group_ready:
            world_size = dist.get_world_size()
    except Exception:
        # Best-effort probe: any failure is treated as a single-process run.
        world_size = 1
    return world_size
def _process_index() -> int:
    """Return the rank of this process in the current distributed job.

    Falls back to ``0`` when ``torch.distributed`` is unavailable, has not
    been initialized, or raises while being imported or queried.
    """
    rank = 0
    try:
        import torch.distributed as dist

        group_ready = dist.is_available() and dist.is_initialized()
        if group_ready:
            rank = dist.get_rank()
    except Exception:
        # Best-effort probe: any failure is treated as rank 0.
        rank = 0
    return rank
def get_pad_token_id(tokenizer, pad_token: str = "pad") -> int:
    """Resolve the token id used for padding.

    Args:
        tokenizer: Object exposing ``pad_token_id`` and ``eos_token_id``
            attributes.
        pad_token: ``"eos"`` selects ``tokenizer.eos_token_id``; any other
            value selects ``tokenizer.pad_token_id``.

    Returns:
        The selected token id (``0`` is a valid id and is returned as-is).

    Raises:
        ValueError: If the selected attribute is ``None``.
    """
    attr_name = "eos_token_id" if pad_token == "eos" else "pad_token_id"
    token_id = getattr(tokenizer, attr_name)
    if token_id is None:
        # Only one attribute is consulted, so name exactly that one instead of
        # implying that both ids are missing.
        raise ValueError(
            f"Tokenizer has no {attr_name} (requested pad_token={pad_token!r})."
        )
    return token_id
def prepare_batch(batch: Dict, config, generator: torch.Generator) -> Dict:
    """Tensorize a collated batch and attach a per-example label-drop mask.

    Args:
        batch: Mapping of field name to value. NumPy arrays are wrapped with
            ``torch.from_numpy``; every other value is passed through as-is.
            Must contain ``"input_ids"``, whose first dimension is the batch
            size.
        config: Object exposing ``label_drop_prob``.
        generator: Random generator used to sample the drop decisions.

    Returns:
        A new dict with the converted fields plus ``"label_drop_mask"``, a
        boolean tensor of shape ``(batch_size,)``. The mask is all ``False``
        when ``config.label_drop_prob`` is not positive.
    """
    tensors = {
        name: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
        for name, value in batch.items()
    }

    num_examples = tensors["input_ids"].shape[0]
    if config.label_drop_prob > 0:
        # One uniform draw per example; drop where the draw is below the
        # configured probability.
        draws = torch.rand((num_examples,), generator=generator)
        drop_mask = draws < config.label_drop_prob
    else:
        drop_mask = torch.zeros((num_examples,), dtype=torch.bool)

    tensors["label_drop_mask"] = drop_mask
    return tensors
def pad_and_truncate(ids_list, target_len, pad_token_id):
    """Pad or truncate each sequence to ``target_len`` and stack them.

    Args:
        ids_list: Iterable of 1-D token-id sequences. NumPy arrays and plain
            Python sequences (e.g. lists returned by a tokenizer) are both
            accepted.
        target_len: Length of every row in the output.
        pad_token_id: Value written into positions past a sequence's end.

    Returns:
        A tuple ``(padded, lengths)`` where ``padded`` has shape
        ``(len(ids_list), target_len)`` and ``lengths`` holds each sequence's
        length after truncation (i.e. ``min(len(ids), target_len)``).
    """
    # np.asarray lets list inputs take the same path as arrays; previously a
    # list that needed padding crashed on ``ids.dtype``.
    rows = [np.asarray(ids)[:target_len] for ids in ids_list]
    lengths = np.array([len(row) for row in rows], dtype=np.int_)

    # Empty sequences carry no dtype information (np.asarray([]) is float64),
    # so only non-empty rows decide the output dtype. This keeps one empty
    # example from silently turning the whole batch of token ids into floats.
    row_dtypes = [row.dtype for row in rows if row.size]
    out_dtype = np.result_type(*row_dtypes) if row_dtypes else np.dtype(np.int64)

    padded = np.full((len(rows), target_len), pad_token_id, dtype=out_dtype)
    for row_idx, row in enumerate(rows):
        if row.size:
            padded[row_idx, : len(row)] = row
    return padded, lengths
def get_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    drop_last: bool = True,
    max_seq_length: int = 512,
    pad_token_id: int = 0,
    max_input_seq_length: Optional[int] = None,
    distributed: bool = True,
):
    """Build a ``DataLoader`` whose collate function pads and masks sequences.

    Args:
        dataset: Indexable dataset of dict examples. Each example must have
            ``"input_ids"`` and may have ``"condition_input_ids"`` as well as
            the pass-through keys ``"index"``, ``"input"`` and ``"target"``.
        batch_size: Number of examples per batch.
        shuffle: Whether to shuffle (forwarded to the sampler when
            ``distributed`` is true, otherwise to the ``DataLoader``).
        num_workers: Worker process count; workers are kept persistent
            whenever this is greater than zero.
        drop_last: Drop the final incomplete batch (also forwarded to the
            distributed sampler).
        max_seq_length: Length every collated sequence is padded/truncated to.
        pad_token_id: Padding value for ``input_ids``.
        max_input_seq_length: Optional cap on the condition prefix length;
            ``None`` keeps the whole condition.
        distributed: When true, shard the dataset with a
            ``DistributedSampler`` sized from the current process group.

    Returns:
        The configured ``DataLoader``.
    """

    def collate_fn(batch_list):
        # Presence of a condition is decided from the first example only.
        first = batch_list[0]
        if "condition_input_ids" in first:
            conditions = [
                np.array(item["condition_input_ids"])[:max_input_seq_length]
                for item in batch_list
            ]
            # The (possibly clipped) condition is placed in front of the target.
            seq_list = [
                np.concatenate([cond, np.array(item["input_ids"])])
                for cond, item in zip(conditions, batch_list)
            ]
            cond_lens = np.array([len(cond) for cond in conditions])
        else:
            seq_list = [np.array(item["input_ids"]) for item in batch_list]
            cond_lens = np.zeros(len(seq_list), dtype=np.int32)

        ids, total_lens = pad_and_truncate(seq_list, max_seq_length, pad_token_id)

        # Position-based boolean maps: which slots belong to the condition
        # prefix and which hold real (non-padding) tokens.
        positions = np.arange(max_seq_length)[None, :]
        is_cond = positions < cond_lens[:, None]
        is_valid = positions < total_lens[:, None]
        # NOTE(review): mask semantics are defined by build_self_attn_cond_masks,
        # which is not visible here; the three outputs are stored in this order.
        encoder_attn, attn, pred = build_self_attn_cond_masks(is_cond, is_valid, xp=np)

        collated = {
            "input_ids": ids,
            "encoder_attention_mask": encoder_attn,
            "attention_mask": attn,
            "cond_seq_mask": pred,
        }
        # Optional metadata is carried through untouched as Python lists.
        collated.update(
            {
                key: [item[key] for item in batch_list]
                for key in ("index", "input", "target")
                if key in first
            }
        )
        return collated

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "collate_fn": collate_fn,
        "drop_last": drop_last,
        "persistent_workers": num_workers > 0,
        "pin_memory": True,
    }

    if not distributed:
        return DataLoader(dataset, shuffle=shuffle, **loader_kwargs)

    # Shuffling is delegated to the sampler in the distributed case.
    sampler = DistributedSampler(
        dataset,
        num_replicas=_process_count(),
        rank=_process_index(),
        shuffle=shuffle,
        drop_last=drop_last,
    )
    return DataLoader(dataset, sampler=sampler, **loader_kwargs)
def load_jsonl_dataset(path, tokenizer, input_key="input", output_key="output"):
    """Read a JSONL eval file into a list of tokenized example dicts.

    Args:
        path: Path to a UTF-8 JSONL file with one JSON object per line.
            Blank lines are skipped.
        tokenizer: Callable invoked as ``tokenizer(text,
            add_special_tokens=False)`` whose result is indexed with
            ``"input_ids"``.
        input_key: Key of the conditioning text in each JSON object.
        output_key: Key of the target text in each JSON object.

    Returns:
        A list of dicts with keys ``"index"`` (zero-based line number in the
        file, counting blank lines), ``"input"``, ``"target"``,
        ``"condition_input_ids"`` and ``"input_ids"``.
    """

    def encode(text):
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle):
            payload = raw_line.strip()
            if payload:
                record = json.loads(payload)
                source_text = record[input_key]
                target_text = record[output_key]
                records.append(
                    {
                        "index": line_no,
                        "input": source_text,
                        "target": target_text,
                        "condition_input_ids": encode(source_text),
                        "input_ids": encode(target_text),
                    }
                )
    return records
# ============================================
# Dataset loading
# ============================================
def _looks_like_save_to_disk_arrow(ds) -> bool:
    """Heuristically detect a `save_to_disk` upload loaded as a metadata row.

    Returns ``True`` only when ``ds`` has exactly one row and a non-empty set
    of columns that all start with an underscore.
    """
    if len(ds) != 1:
        return False
    underscore_flags = [name.startswith("_") for name in ds.column_names]
    # Non-empty and every column is underscore-prefixed.
    return bool(underscore_flags) and all(underscore_flags)
def load_dataset_split(path: str, dataset_cache_dir=None):
    """Load a single-split dataset and set its output format to NumPy.

    ``datasets.load_dataset`` is tried first; if it raises, the path is
    opened with ``load_from_disk`` instead. When the loaded result looks like
    the one-row metadata table of a `save_to_disk` upload, the repository is
    fetched with ``snapshot_download`` and re-opened with ``load_from_disk``.

    Args:
        path: Dataset identifier or local directory.
        dataset_cache_dir: Cache directory forwarded to ``load_dataset`` and
            ``snapshot_download``.

    Returns:
        The dataset with ``set_format(type="numpy")`` applied to all columns.

    Raises:
        ValueError: If a ``DatasetDict`` with other than exactly one split is
            loaded.
    """
    from datasets import DatasetDict, load_dataset as hf_load_dataset, load_from_disk

    def unwrap_single_split(loaded):
        # A plain dataset passes through; a DatasetDict must hold one split.
        if not isinstance(loaded, DatasetDict):
            return loaded
        splits = list(loaded.keys())
        if len(splits) != 1:
            raise ValueError(f"Expected dataset at {path!r} to have a single split, got {splits}.")
        return loaded[splits[0]]

    try:
        loaded = hf_load_dataset(path, cache_dir=dataset_cache_dir)
    except Exception:
        # Any loader failure falls back to the on-disk Arrow layout.
        loaded = load_from_disk(path)
    ds = unwrap_single_split(loaded)

    if _looks_like_save_to_disk_arrow(ds):
        from huggingface_hub import snapshot_download

        log_for_0(
            f"Dataset at {path!r} looks like a save_to_disk-format HF repo; "
            f"re-downloading via snapshot_download + load_from_disk."
        )
        local_dir = snapshot_download(repo_id=path, repo_type="dataset", cache_dir=dataset_cache_dir)
        ds = unwrap_single_split(load_from_disk(local_dir))

    ds.set_format(type="numpy", columns=ds.column_names)
    return ds
def load_dataset(config, dataset_cache_dir=None):
    """Load the train dataset and, when configured, the eval dataset.

    Args:
        config: Object exposing ``data_path`` and ``eval_data_path``.
        dataset_cache_dir: Cache directory forwarded to ``load_dataset_split``.

    Returns:
        ``(train_dataset, eval_dataset)``; ``eval_dataset`` is ``None`` when
        ``config.eval_data_path`` is falsy.
    """
    log_for_0(f"Loading dataset from {config.data_path}...")
    train_dataset = load_dataset_split(config.data_path, dataset_cache_dir)
    log_for_0(f"Train size: {len(train_dataset)}")

    # Guard clause: nothing more to load without an eval path.
    if not config.eval_data_path:
        log_for_0("No eval dataset")
        return train_dataset, None

    eval_dataset = load_dataset_split(config.eval_data_path, dataset_cache_dir)
    log_for_0(f"Eval size: {len(eval_dataset)}")
    return train_dataset, eval_dataset