"""Shared LM dataset helpers for fair cross-method comparisons."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import torch

try:
    from datasets import load_dataset
    from datasets import Dataset as HFDataset
except Exception:  # `datasets` is optional at import time; the builders below fail loudly without it.
    load_dataset = None
    HFDataset = None


def _normalize_config(config: Optional[str]) -> Optional[str]:
    """Treat sentinel strings such as "none", "null", or "-" as no config."""
    if config is None:
        return None
    if config.strip().lower() in {"none", "null", "-"}:
        return None
    return config


def guess_text_field(dataset) -> str:
    """Pick the text column: prefer "text", otherwise fall back to the first column."""
    if hasattr(dataset, "column_names") and dataset.column_names:
        if "text" in dataset.column_names:
            return "text"
        return dataset.column_names[0]
    if hasattr(dataset, "features"):
        names = list(dataset.features.keys())
        if "text" in names:
            return "text"
        if names:
            return names[0]
    return "text"


def normalize_dataset_name(name: str) -> str:
    """Map user-facing dataset names (including common misspellings) to canonical keys."""
    normalized = name.strip().lower()
    aliases = {
        "bookcorpus": "bookcorpus",
        "boockcorpus": "bookcorpus",
        "slimpajama": "slimpajama",
        "dkyoon/slimpajama-6b": "slimpajama",
    }
    if normalized not in aliases:
        raise ValueError(f"Unsupported dataset: {name}")
    return aliases[normalized]


def resolve_dataset_spec(
    name: str,
    config: Optional[str] = None,
    split: str = "train",
) -> Tuple[str, Optional[str], str]:
    """Resolve a dataset alias to the (hub name, config, split) triple passed to ``load_dataset``."""
    normalized = normalize_dataset_name(name)
    if normalized == "bookcorpus":
        return "bookcorpus", _normalize_config(config), split
    if normalized == "slimpajama":
        return "DKYoon/SlimPajama-6B", _normalize_config(config), split
    raise ValueError(f"Unsupported dataset: {name}")


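# Alias resolution examples (illustrative, derived from the mapping above):
#   resolve_dataset_spec("SlimPajama")         -> ("DKYoon/SlimPajama-6B", None, "train")
#   resolve_dataset_spec("bookcorpus", "none") -> ("bookcorpus", None, "train")

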
def _sample_dataset_rows(dataset, target: int, seed: int) -> List[Dict[str, object]]:
    """Shuffle (when supported) and take up to ``target`` rows from ``dataset``."""
    if target <= 0:
        return []
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        pass

    if hasattr(dataset, "__len__"):
        limit = min(target, len(dataset))
        dataset = dataset.select(range(limit))
        return [row for row in dataset]

    rows = []
    for row in dataset:
        rows.append(row)
        if len(rows) >= target:
            break
    return rows


def _iter_dataset_rows(dataset, seed: int) -> Iterator[Dict[str, object]]:
    """Yield rows from ``dataset``, shuffled with ``seed`` when the backend supports it."""
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        pass
    for row in dataset:
        yield row


def load_named_texts(
    dataset_name: str,
    *,
    config: Optional[str] = None,
    split: str = "train",
    text_field: Optional[str] = None,
    num_samples: int = 0,
    seed: int = 0,
) -> List[str]:
    """Load non-empty text strings from a supported dataset, optionally sampling ``num_samples`` rows."""
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    hf_name, hf_config, hf_split = resolve_dataset_spec(dataset_name, config, split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    rows = dataset if num_samples <= 0 else _sample_dataset_rows(dataset, num_samples, seed)
    field = text_field or guess_text_field(dataset)

    texts: List[str] = []
    for row in rows:
        value = row.get(field, None) if isinstance(row, dict) else None
        if isinstance(value, str) and value.strip():
            texts.append(value)
    return texts


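# Usage sketch for ``load_named_texts`` (illustrative; downloads the dataset on first use):
#   texts = load_named_texts("bookcorpus", num_samples=1000, seed=0)
#   # -> up to 1000 shuffled, non-empty strings from the BookCorpus train split

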
def build_token_chunks_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[torch.Tensor]:
    """Stream rows, tokenize their text field, and pack the token stream into ``seq_len``-long chunks."""
    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_sequences <= 0 else num_sequences
    rows_seen = 0

    for row in rows:
        if max_rows > 0 and rows_seen >= max_rows:
            break
        rows_seen += 1

        value = row.get(text_field, None) if isinstance(row, dict) else None
        if not isinstance(value, str) or not value.strip():
            continue

        ids = tokenizer.encode(value, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if not ids:
            continue

        buffer.extend(ids)
        while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
            chunk = buffer[:seq_len]
            buffer = buffer[seq_len:]
            chunks.append(torch.tensor(chunk, dtype=torch.long))
        if limit is not None and len(chunks) >= limit:
            break

    return chunks


def collect_texts_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    target_tokens: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[str]:
    """Collect raw texts from rows until ``target_tokens`` (if set) or ``max_rows`` is reached."""
    texts: List[str] = []
    token_count = 0
    rows_seen = 0

    for row in rows:
        if max_rows > 0 and rows_seen >= max_rows:
            break
        rows_seen += 1

        value = row.get(text_field, None) if isinstance(row, dict) else None
        if not isinstance(value, str) or not value.strip():
            continue

        texts.append(value)
        if target_tokens > 0:
            ids = tokenizer.encode(value, add_special_tokens=False)
            if add_bos and tokenizer.bos_token_id is not None:
                ids = [tokenizer.bos_token_id] + ids
            token_count += len(ids)
            if token_count >= target_tokens:
                break

    return texts


def build_token_chunks(
    texts: Iterable[str],
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
) -> List[torch.Tensor]:
    """Tokenize ``texts`` and pack the concatenated token stream into ``seq_len``-long chunks."""
    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_sequences <= 0 else num_sequences

    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if not ids:
            continue

        buffer.extend(ids)
        while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
            chunk = buffer[:seq_len]
            buffer = buffer[seq_len:]
            chunks.append(torch.tensor(chunk, dtype=torch.long))
        if limit is not None and len(chunks) >= limit:
            break

    return chunks


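# Packing semantics of the chunk builders above (illustrative): with seq_len=4, texts
# tokenizing to [1, 2, 3] and [4, 5, 6, 7] are concatenated into [1, 2, 3, 4, 5, 6, 7],
# which yields a single chunk tensor([1, 2, 3, 4]); the leftover [5, 6, 7] is dropped
# because it never fills a complete sequence.

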
class TokenChunkDataset(torch.utils.data.Dataset):
    """Fixed-length token chunks exposed as causal-LM examples (labels mirror the inputs)."""

    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        input_ids = self.chunks[idx]
        attention_mask = torch.ones_like(input_ids)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.clone(),
        }


class TokenOnlyDataset(torch.utils.data.Dataset):
    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.chunks[idx]


class TokenInputMaskDataset(torch.utils.data.Dataset):
    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        input_ids = self.chunks[idx]
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
        }


@dataclass
class SharedLMDataSpec:
    """Configuration shared by every dataloader builder in this module.

    ``target_tokens`` (when > 0) is converted into a sequence count by ``build_chunks``;
    ``num_samples`` caps the number of raw rows read when no sequence/token budget is set.
    """

    dataset: str
    config: Optional[str] = None
    split: str = "train"
    text_field: Optional[str] = None
    num_samples: int = 0
    seq_len: int = 2048
    num_sequences: int = 0
    target_tokens: int = 0
    batch_size: int = 1
    shuffle: bool = False
    num_workers: int = 0
    seed: int = 0
    add_bos: bool = False


def build_chunks(spec: SharedLMDataSpec, tokenizer) -> List[torch.Tensor]:
    """Load the dataset named by ``spec`` and pack it into fixed-length token chunks."""
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )

    target_sequences = spec.num_sequences
    if spec.target_tokens > 0:
        token_sequences = (spec.target_tokens + spec.seq_len - 1) // spec.seq_len
        target_sequences = max(target_sequences, token_sequences)
    row_limit = spec.num_samples if target_sequences <= 0 else 0

    rows = _iter_dataset_rows(dataset, spec.seed)
    text_field = spec.text_field or guess_text_field(dataset)
    chunks = build_token_chunks_from_rows(
        rows,
        text_field=text_field,
        tokenizer=tokenizer,
        seq_len=spec.seq_len,
        num_sequences=target_sequences,
        add_bos=spec.add_bos,
        max_rows=row_limit,
    )
    return chunks


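# Sizing example for ``build_chunks`` (illustrative): target_tokens=100_000 with
# seq_len=2048 requests ceil(100_000 / 2048) = 49 sequences, i.e. 49 * 2048 = 100_352
# tokens, so the token budget is met or slightly exceeded rather than undershot.

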
def build_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    """Batched dataloader of fixed-length chunks with ``input_ids``/``attention_mask``/``labels``."""
    chunks = build_chunks(spec, tokenizer)
    dataset = TokenChunkDataset(chunks)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


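# End-to-end usage sketch (illustrative; assumes ``transformers`` is installed and uses
# the public "gpt2" tokenizer purely as an example):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   spec = SharedLMDataSpec(dataset="bookcorpus", seq_len=512, target_tokens=50_000, batch_size=4)
#   loader = build_dataloader(spec, tok)
#   batch = next(iter(loader))  # "input_ids"/"attention_mask"/"labels", each of shape [4, 512]

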
def build_text_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    """Batched dataloader of raw text strings; the last partial batch is dropped."""
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    rows = _iter_dataset_rows(dataset, spec.seed)
    text_field = spec.text_field or guess_text_field(dataset)
    row_limit = spec.num_samples
    texts = collect_texts_from_rows(
        rows,
        text_field=text_field,
        tokenizer=tokenizer,
        target_tokens=spec.target_tokens,
        add_bos=spec.add_bos,
        max_rows=row_limit,
    )
    return torch.utils.data.DataLoader(
        texts,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
        drop_last=True,
    )


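# Note: ``build_text_dataloader`` batches untokenized strings, so the default collate
# yields a list of ``batch_size`` texts per step; ``build_replaceme_dataloader`` below
# simply delegates to it.

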
def build_uidl_post_train_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Dataloader of labelled token chunks for the UIDL post-training stage."""
    dataset = TokenChunkDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


def build_uidl_similarity_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Dataloader of token chunks with attention masks (no labels) for the UIDL similarity pass."""
    dataset = TokenInputMaskDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


def build_shortened_llm_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Dataloader of bare token-id tensors for Shortened-LLM calibration."""
    dataset = TokenOnlyDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


def build_shortened_llm_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    """Stack token chunks into a single (num_chunks, seq_len) tensor for Shortened-LLM."""
    chunks = build_chunks(spec, tokenizer)
    if not chunks:
        return torch.empty((0, spec.seq_len), dtype=torch.long)
    return torch.stack(chunks, dim=0)


def build_llmpruner_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    """Stack token chunks into a single (num_chunks, seq_len) tensor for LLM-Pruner."""
    chunks = build_chunks(spec, tokenizer)
    if not chunks:
        return torch.empty((0, spec.seq_len), dtype=torch.long)
    return torch.stack(chunks, dim=0)


def build_replaceme_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    """Raw-text dataloader for ReplaceMe; delegates to ``build_text_dataloader``."""
    return build_text_dataloader(spec, tokenizer)


def build_hf_causal_dataset(spec: SharedLMDataSpec, tokenizer):
    """Materialise token chunks as a ``datasets.Dataset`` with causal-LM columns."""
    if HFDataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")

    chunks = build_chunks(spec, tokenizer)
    payload = {
        "input_ids": [chunk.tolist() for chunk in chunks],
        "attention_mask": [torch.ones_like(chunk).tolist() for chunk in chunks],
        "labels": [chunk.tolist() for chunk in chunks],
    }
    return HFDataset.from_dict(payload)
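

if __name__ == "__main__":
    # Minimal offline smoke test (illustrative only): a toy whitespace "tokenizer"
    # exercises the packing helpers without downloading any dataset.
    class _ToyTokenizer:
        bos_token_id = None

        def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
            return [len(word) for word in text.split()]

    _demo = build_token_chunks(["one two three", "four five six seven"], _ToyTokenizer(), seq_len=4)
    print(f"packed {len(_demo)} chunk(s), first chunk: {_demo[0].tolist()}")
    print(TokenChunkDataset(_demo)[0])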