# File: temp_ss/src/common_lm_data.py
# Uploaded by LJYAI ("upload src", commit 2c44909, verified)
#!/usr/bin/env python3
"""Shared LM dataset helpers for fair cross-method comparisons."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
import torch
try:
from datasets import load_dataset
from datasets import Dataset as HFDataset
except Exception: # pragma: no cover - optional dependency
load_dataset = None
HFDataset = None
def _normalize_config(config: Optional[str]) -> Optional[str]:
if config is None:
return None
if config.strip().lower() in {"none", "null", "-"}:
return None
return config
def guess_text_field(dataset) -> str:
    """Pick the most likely text column: prefer "text", else the first column.

    Checks ``column_names`` first (map-style HF datasets), then ``features``
    (streaming datasets); falls back to the literal name "text".
    """
    columns = getattr(dataset, "column_names", None)
    if columns:
        return "text" if "text" in columns else columns[0]
    if hasattr(dataset, "features"):
        feature_names = list(dataset.features.keys())
        if "text" in feature_names:
            return "text"
        if feature_names:
            return feature_names[0]
    return "text"
def normalize_dataset_name(name: str) -> str:
    """Canonicalize a user-supplied dataset name (case/whitespace tolerant).

    Raises:
        ValueError: if the name is not one of the supported aliases.
    """
    key = name.strip().lower()
    alias_map = {
        "bookcorpus": "bookcorpus",
        "boockcorpus": "bookcorpus",  # tolerate a common misspelling
        "slimpajama": "slimpajama",
        "dkyoon/slimpajama-6b": "slimpajama",
    }
    canonical = alias_map.get(key)
    if canonical is None:
        raise ValueError(f"Unsupported dataset: {name}")
    return canonical
def resolve_dataset_spec(
    name: str,
    config: Optional[str] = None,
    split: str = "train",
) -> Tuple[str, Optional[str], str]:
    """Translate a dataset alias into a (hf_repo_id, config, split) triple.

    Raises:
        ValueError: if the alias is not supported.
    """
    canonical = normalize_dataset_name(name)
    hf_repos = {
        "bookcorpus": "bookcorpus",
        "slimpajama": "DKYoon/SlimPajama-6B",
    }
    if canonical not in hf_repos:
        raise ValueError(f"Unsupported dataset: {name}")
    return hf_repos[canonical], _normalize_config(config), split
def _sample_dataset_rows(dataset, target: int, seed: int) -> List[Dict[str, object]]:
if target <= 0:
return []
try:
dataset = dataset.shuffle(seed=seed)
except Exception:
pass
if hasattr(dataset, "__len__"):
limit = min(target, len(dataset))
dataset = dataset.select(range(limit))
return [row for row in dataset]
rows = []
for row in dataset:
rows.append(row)
if len(rows) >= target:
break
return rows
def _iter_dataset_rows(dataset, seed: int) -> Iterator[Dict[str, object]]:
try:
dataset = dataset.shuffle(seed=seed)
except Exception:
pass
for row in dataset:
yield row
def load_named_texts(
    dataset_name: str,
    *,
    config: Optional[str] = None,
    split: str = "train",
    text_field: Optional[str] = None,
    num_samples: int = 0,
    seed: int = 0,
) -> List[str]:
    """Load non-empty text strings from a supported HF dataset.

    ``num_samples`` <= 0 means "take every row"; otherwise rows are sampled
    (shuffled when supported) up to that count before text extraction.
    """
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    hf_name, hf_config, hf_split = resolve_dataset_spec(dataset_name, config, split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    if num_samples > 0:
        rows = _sample_dataset_rows(dataset, num_samples, seed)
    else:
        rows = dataset
    field = text_field or guess_text_field(dataset)
    collected: List[str] = []
    for row in rows:
        candidate = row.get(field, None) if isinstance(row, dict) else None
        if isinstance(candidate, str) and candidate.strip():
            collected.append(candidate)
    return collected
def build_token_chunks_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[torch.Tensor]:
    """Tokenize row texts and pack them into fixed-length LongTensor chunks.

    ``num_sequences`` and ``max_rows`` of 0 mean "no limit"; trailing tokens
    that do not fill a full ``seq_len`` chunk are discarded.
    """
    out: List[torch.Tensor] = []
    pending: List[int] = []
    max_chunks = num_sequences if num_sequences > 0 else None
    consumed = 0
    for row in rows:
        if max_rows > 0 and consumed >= max_rows:
            break
        consumed += 1
        text = row.get(text_field, None) if isinstance(row, dict) else None
        if not isinstance(text, str) or not text.strip():
            continue  # skip missing/blank texts
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            token_ids = [tokenizer.bos_token_id] + token_ids
        if not token_ids:
            continue
        pending.extend(token_ids)
        while len(pending) >= seq_len and (max_chunks is None or len(out) < max_chunks):
            out.append(torch.tensor(pending[:seq_len], dtype=torch.long))
            pending = pending[seq_len:]
        if max_chunks is not None and len(out) >= max_chunks:
            break
    return out
def collect_texts_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    target_tokens: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[str]:
    """Gather raw texts from rows, optionally stopping near a token budget.

    The text that crosses ``target_tokens`` is still included, so the total
    token count may slightly exceed the budget. ``target_tokens``/``max_rows``
    of 0 disable the respective limit.
    """
    collected: List[str] = []
    total_tokens = 0
    consumed = 0
    for row in rows:
        if max_rows > 0 and consumed >= max_rows:
            break
        consumed += 1
        text = row.get(text_field, None) if isinstance(row, dict) else None
        if not isinstance(text, str) or not text.strip():
            continue  # skip missing/blank texts
        collected.append(text)
        if target_tokens > 0:
            token_ids = tokenizer.encode(text, add_special_tokens=False)
            if add_bos and tokenizer.bos_token_id is not None:
                token_ids = [tokenizer.bos_token_id] + token_ids
            total_tokens += len(token_ids)
            if total_tokens >= target_tokens:
                break
    return collected
def build_token_chunks(
    texts: Iterable[str],
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
) -> List[torch.Tensor]:
    """Concatenate tokenized texts and slice into ``seq_len``-long chunks.

    ``num_sequences`` of 0 means unlimited; a trailing partial chunk is
    dropped.
    """
    out: List[torch.Tensor] = []
    pending: List[int] = []
    max_chunks = num_sequences if num_sequences > 0 else None
    for text in texts:
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            token_ids = [tokenizer.bos_token_id] + token_ids
        if not token_ids:
            continue
        pending.extend(token_ids)
        while len(pending) >= seq_len and (max_chunks is None or len(out) < max_chunks):
            out.append(torch.tensor(pending[:seq_len], dtype=torch.long))
            pending = pending[seq_len:]
        if max_chunks is not None and len(out) >= max_chunks:
            break
    return out
class TokenChunkDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding input_ids/attention_mask/labels dicts.

    Labels are a clone of input_ids (standard causal-LM training setup).
    """

    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        ids = self.chunks[idx]
        mask = torch.ones_like(ids)
        return {
            "input_ids": ids,
            "attention_mask": mask,
            "labels": ids.clone(),
        }
class TokenOnlyDataset(torch.utils.data.Dataset):
    """Map-style dataset that returns the raw token tensor per index."""

    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.chunks[idx]
class TokenInputMaskDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding input_ids plus an all-ones attention mask
    (no labels — for similarity/inference-style passes)."""

    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        ids = self.chunks[idx]
        mask = torch.ones_like(ids)
        return {"input_ids": ids, "attention_mask": mask}
@dataclass
class SharedLMDataSpec:
    """Configuration for building shared LM calibration/evaluation data.

    Zero values for ``num_samples``/``num_sequences``/``target_tokens`` mean
    "no limit"/"disabled" throughout the helpers in this module.
    """

    dataset: str  # user-facing alias, resolved via resolve_dataset_spec()
    config: Optional[str] = None  # HF config; "none"/"null"/"-" normalize to None
    split: str = "train"
    text_field: Optional[str] = None  # text column; auto-guessed when None
    num_samples: int = 0  # max rows to read (0 = unlimited)
    seq_len: int = 2048  # tokens per packed chunk
    num_sequences: int = 0  # max chunks to build (0 = unlimited)
    target_tokens: int = 0  # approximate token budget (0 = disabled)
    batch_size: int = 1
    shuffle: bool = False  # DataLoader-level shuffling
    num_workers: int = 0
    seed: int = 0  # seed for dataset-level shuffling in row iterators
    add_bos: bool = False  # prepend BOS to each text when tokenizer has one
def build_chunks(spec: SharedLMDataSpec, tokenizer) -> List[torch.Tensor]:
    """Load the dataset described by ``spec`` and pack it into token chunks."""
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    # A token budget implies a minimum number of full-length sequences.
    wanted_sequences = spec.num_sequences
    if spec.target_tokens > 0:
        from_budget = -(-spec.target_tokens // spec.seq_len)  # ceil division
        wanted_sequences = max(wanted_sequences, from_budget)
    # The row cap only applies when no sequence/token target drives termination.
    row_cap = spec.num_samples if wanted_sequences <= 0 else 0
    field = spec.text_field or guess_text_field(dataset)
    return build_token_chunks_from_rows(
        _iter_dataset_rows(dataset, spec.seed),
        text_field=field,
        tokenizer=tokenizer,
        seq_len=spec.seq_len,
        num_sequences=wanted_sequences,
        add_bos=spec.add_bos,
        max_rows=row_cap,
    )
def build_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    """DataLoader over TokenChunkDataset items (input_ids/mask/labels dicts)."""
    ds = TokenChunkDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(
        ds,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )
def build_text_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    """DataLoader over raw text strings (each batch is a list of str)."""
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    field = spec.text_field or guess_text_field(dataset)
    texts = collect_texts_from_rows(
        _iter_dataset_rows(dataset, spec.seed),
        text_field=field,
        tokenizer=tokenizer,
        target_tokens=spec.target_tokens,
        add_bos=spec.add_bos,
        max_rows=spec.num_samples,
    )
    # drop_last keeps every batch the same size for methods that require it.
    return torch.utils.data.DataLoader(
        texts,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
        drop_last=True,
    )
def build_uidl_post_train_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """UIDL post-training loader: dict batches including labels."""
    loader_kwargs = {
        "batch_size": spec.batch_size,
        "shuffle": spec.shuffle,
        "num_workers": spec.num_workers,
    }
    ds = TokenChunkDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(ds, **loader_kwargs)
def build_uidl_similarity_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """UIDL similarity loader: input_ids + attention_mask only (no labels)."""
    loader_kwargs = {
        "batch_size": spec.batch_size,
        "shuffle": spec.shuffle,
        "num_workers": spec.num_workers,
    }
    ds = TokenInputMaskDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(ds, **loader_kwargs)
def build_shortened_llm_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Shortened-LLM loader: bare token tensors per batch."""
    loader_kwargs = {
        "batch_size": spec.batch_size,
        "shuffle": spec.shuffle,
        "num_workers": spec.num_workers,
    }
    ds = TokenOnlyDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(ds, **loader_kwargs)
def build_shortened_llm_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    """Stack packed chunks into a (num_chunks, seq_len) LongTensor.

    Returns an empty (0, seq_len) tensor when no chunks could be built.
    """
    packed = build_chunks(spec, tokenizer)
    if packed:
        return torch.stack(packed, dim=0)
    return torch.empty((0, spec.seq_len), dtype=torch.long)
def build_llmpruner_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    """LLM-Pruner calibration examples as one (num_chunks, seq_len) LongTensor.

    Returns an empty (0, seq_len) tensor when no chunks could be built.
    """
    packed = build_chunks(spec, tokenizer)
    if packed:
        return torch.stack(packed, dim=0)
    return torch.empty((0, spec.seq_len), dtype=torch.long)
def build_replaceme_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """ReplaceMe consumes raw text batches; delegate to the text loader."""
    return build_text_dataloader(spec, tokenizer)
def build_hf_causal_dataset(spec: SharedLMDataSpec, tokenizer):
    """Materialize packed chunks as an HF Dataset with causal-LM columns
    (input_ids, attention_mask, labels as plain Python int lists)."""
    if HFDataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    ids_rows = [chunk.tolist() for chunk in build_chunks(spec, tokenizer)]
    columns = {
        "input_ids": ids_rows,
        "attention_mask": [[1] * len(row) for row in ids_rows],
        "labels": [list(row) for row in ids_rows],
    }
    return HFDataset.from_dict(columns)