#!/usr/bin/env python3
"""Shared LM dataset helpers for fair cross-method comparisons."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import torch

try:
    from datasets import load_dataset
    from datasets import Dataset as HFDataset
except Exception:  # pragma: no cover - optional dependency
    load_dataset = None
    HFDataset = None


def _normalize_config(config: Optional[str]) -> Optional[str]:
    if config is None:
        return None
    if config.strip().lower() in {"none", "null", "-"}:
        return None
    return config


def guess_text_field(dataset) -> str:
    if hasattr(dataset, "column_names") and dataset.column_names:
        if "text" in dataset.column_names:
            return "text"
        return dataset.column_names[0]
    if hasattr(dataset, "features"):
        names = list(dataset.features.keys())
        if "text" in names:
            return "text"
        if names:
            return names[0]
    return "text"


def normalize_dataset_name(name: str) -> str:
    normalized = name.strip().lower()
    aliases = {
        "bookcorpus": "bookcorpus",
        "boockcorpus": "bookcorpus",
        "slimpajama": "slimpajama",
        "dkyoon/slimpajama-6b": "slimpajama",
    }
    if normalized not in aliases:
        raise ValueError(f"Unsupported dataset: {name}")
    return aliases[normalized]


def resolve_dataset_spec(
    name: str,
    config: Optional[str] = None,
    split: str = "train",
) -> Tuple[str, Optional[str], str]:
    normalized = normalize_dataset_name(name)
    if normalized == "bookcorpus":
        return "bookcorpus", _normalize_config(config), split
    if normalized == "slimpajama":
        return "DKYoon/SlimPajama-6B", _normalize_config(config), split
    raise ValueError(f"Unsupported dataset: {name}")


def _sample_dataset_rows(dataset, target: int, seed: int) -> List[Dict[str, object]]:
    if target <= 0:
        return []
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        pass
    if hasattr(dataset, "__len__"):
        limit = min(target, len(dataset))
        dataset = dataset.select(range(limit))
        return [row for row in dataset]
    rows = []
    for row in dataset:
        rows.append(row)
        if len(rows) >= target:
            break
    return rows


def _iter_dataset_rows(dataset, seed: int) -> Iterator[Dict[str, object]]:
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        pass
    for row in dataset:
        yield row


def load_named_texts(
    dataset_name: str,
    *,
    config: Optional[str] = None,
    split: str = "train",
    text_field: Optional[str] = None,
    num_samples: int = 0,
    seed: int = 0,
) -> List[str]:
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    hf_name, hf_config, hf_split = resolve_dataset_spec(dataset_name, config, split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    rows = dataset if num_samples <= 0 else _sample_dataset_rows(dataset, num_samples, seed)
    field = text_field or guess_text_field(dataset)
    texts: List[str] = []
    for row in rows:
        value = row.get(field, None) if isinstance(row, dict) else None
        if isinstance(value, str) and value.strip():
            texts.append(value)
    return texts


def build_token_chunks_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[torch.Tensor]:
    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_sequences <= 0 else num_sequences
    rows_seen = 0
    for row in rows:
        if max_rows > 0 and rows_seen >= max_rows:
            break
        rows_seen += 1
        value = row.get(text_field, None) if isinstance(row, dict) else None
        if not isinstance(value, str) or not value.strip():
            continue
        ids = tokenizer.encode(value, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if not ids:
            continue
        buffer.extend(ids)
        while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
            chunk = buffer[:seq_len]
            buffer = buffer[seq_len:]
            chunks.append(torch.tensor(chunk, dtype=torch.long))
        if limit is not None and len(chunks) >= limit:
            break
    return chunks


def collect_texts_from_rows(
    rows: Iterable[Dict[str, object]],
    *,
    text_field: str,
    tokenizer,
    target_tokens: int = 0,
    add_bos: bool = False,
    max_rows: int = 0,
) -> List[str]:
    texts: List[str] = []
    token_count = 0
    rows_seen = 0
    for row in rows:
        if max_rows > 0 and rows_seen >= max_rows:
            break
        rows_seen += 1
        value = row.get(text_field, None) if isinstance(row, dict) else None
        if not isinstance(value, str) or not value.strip():
            continue
        texts.append(value)
        if target_tokens > 0:
            ids = tokenizer.encode(value, add_special_tokens=False)
            if add_bos and tokenizer.bos_token_id is not None:
                ids = [tokenizer.bos_token_id] + ids
            token_count += len(ids)
            if token_count >= target_tokens:
                break
    return texts


def build_token_chunks(
    texts: Iterable[str],
    tokenizer,
    seq_len: int,
    num_sequences: int = 0,
    add_bos: bool = False,
) -> List[torch.Tensor]:
    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_sequences <= 0 else num_sequences
    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if add_bos and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if not ids:
            continue
        buffer.extend(ids)
        while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
            chunk = buffer[:seq_len]
            buffer = buffer[seq_len:]
            chunks.append(torch.tensor(chunk, dtype=torch.long))
        if limit is not None and len(chunks) >= limit:
            break
    return chunks


class TokenChunkDataset(torch.utils.data.Dataset):
    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        input_ids = self.chunks[idx]
        attention_mask = torch.ones_like(input_ids)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.clone(),
        }


class TokenOnlyDataset(torch.utils.data.Dataset):
    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.chunks[idx]


class TokenInputMaskDataset(torch.utils.data.Dataset):
    def __init__(self, chunks: List[torch.Tensor]) -> None:
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        input_ids = self.chunks[idx]
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
        }


@dataclass
class SharedLMDataSpec:
    dataset: str
    config: Optional[str] = None
    split: str = "train"
    text_field: Optional[str] = None
    num_samples: int = 0
    seq_len: int = 2048
    num_sequences: int = 0
    target_tokens: int = 0
    batch_size: int = 1
    shuffle: bool = False
    num_workers: int = 0
    seed: int = 0
    add_bos: bool = False


def build_chunks(spec: SharedLMDataSpec, tokenizer) -> List[torch.Tensor]:
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    target_sequences = spec.num_sequences
    if spec.target_tokens > 0:
        token_sequences = (spec.target_tokens + spec.seq_len - 1) // spec.seq_len
        target_sequences = max(target_sequences, token_sequences)
    row_limit = spec.num_samples if target_sequences <= 0 else 0
    rows = _iter_dataset_rows(dataset, spec.seed)
    text_field = spec.text_field or guess_text_field(dataset)
    chunks = build_token_chunks_from_rows(
        rows,
        text_field=text_field,
        tokenizer=tokenizer,
        seq_len=spec.seq_len,
        num_sequences=target_sequences,
        add_bos=spec.add_bos,
        max_rows=row_limit,
    )
    return chunks


def build_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    chunks = build_chunks(spec, tokenizer)
    dataset = TokenChunkDataset(chunks)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


def build_text_dataloader(spec: SharedLMDataSpec, tokenizer) -> torch.utils.data.DataLoader:
    if load_dataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    hf_name, hf_config, hf_split = resolve_dataset_spec(spec.dataset, spec.config, spec.split)
    dataset = load_dataset(
        hf_name,
        hf_config,
        split=hf_split,
        trust_remote_code=True,
    )
    rows = _iter_dataset_rows(dataset, spec.seed)
    text_field = spec.text_field or guess_text_field(dataset)
    row_limit = spec.num_samples
    texts = collect_texts_from_rows(
        rows,
        text_field=text_field,
        tokenizer=tokenizer,
        target_tokens=spec.target_tokens,
        add_bos=spec.add_bos,
        max_rows=row_limit,
    )
    return torch.utils.data.DataLoader(
        texts,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
        drop_last=True,
    )


def build_uidl_post_train_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    dataset = TokenChunkDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


def build_uidl_similarity_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    dataset = TokenInputMaskDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


def build_shortened_llm_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    dataset = TokenOnlyDataset(build_chunks(spec, tokenizer))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=spec.batch_size,
        shuffle=spec.shuffle,
        num_workers=spec.num_workers,
    )


def build_shortened_llm_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    chunks = build_chunks(spec, tokenizer)
    if not chunks:
        return torch.empty((0, spec.seq_len), dtype=torch.long)
    return torch.stack(chunks, dim=0)


def build_llmpruner_examples(spec: SharedLMDataSpec, tokenizer) -> torch.Tensor:
    chunks = build_chunks(spec, tokenizer)
    if not chunks:
        return torch.empty((0, spec.seq_len), dtype=torch.long)
    return torch.stack(chunks, dim=0)


def build_replaceme_dataloader(
    spec: SharedLMDataSpec,
    tokenizer,
) -> torch.utils.data.DataLoader:
    return build_text_dataloader(spec, tokenizer)


def build_hf_causal_dataset(spec: SharedLMDataSpec, tokenizer):
    if HFDataset is None:
        raise SystemExit("datasets is required for shared LM dataloaders")
    chunks = build_chunks(spec, tokenizer)
    payload = {
        "input_ids": [chunk.tolist() for chunk in chunks],
        "attention_mask": [torch.ones_like(chunk).tolist() for chunk in chunks],
        "labels": [chunk.tolist() for chunk in chunks],
    }
    return HFDataset.from_dict(payload)
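

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch, not part of the library surface: it shows how a
    # SharedLMDataSpec feeds build_dataloader. The "gpt2" tokenizer below is
    # only a stand-in assumption (any Hugging Face tokenizer with .encode and
    # .bos_token_id works), and running this downloads the bookcorpus dataset.
    from transformers import AutoTokenizer

    example_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    example_spec = SharedLMDataSpec(
        dataset="bookcorpus",
        seq_len=512,
        num_sequences=8,
        batch_size=2,
    )
    loader = build_dataloader(example_spec, example_tokenizer)
    batch = next(iter(loader))
    # Each batch carries input_ids, attention_mask, and labels of shape
    # (batch_size, seq_len); here that is expected to be (2, 512).
    print(batch["input_ids"].shape)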