#!/usr/bin/env python3
"""Dataset and text helpers for fuse_layers."""

import argparse
from typing import Dict, List, Optional

import torch

try:
    from datasets import load_dataset
except Exception:  # pragma: no cover - optional dependency
    load_dataset = None


def guess_text_field(dataset) -> str:
    """Best-effort guess of the text column name for a dataset."""
    if hasattr(dataset, "column_names") and dataset.column_names:
        if "text" in dataset.column_names:
            return "text"
        return dataset.column_names[0]
    if hasattr(dataset, "features"):
        names = list(dataset.features.keys())
        if "text" in names:
            return "text"
        if names:
            return names[0]
    return "text"


def _normalize_config(config: Optional[str]) -> Optional[str]:
    """Treat placeholder strings such as "none" or "-" as a missing config."""
    if config is None:
        return None
    if config.strip().lower() in {"none", "null", "-"}:
        return None
    return config


def expand_dataset_configs(
    datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
    """Align --dataset_config values with the list of datasets."""
    if not configs:
        return [None] * len(datasets)
    if len(configs) == 1 and len(datasets) > 1:
        return [_normalize_config(configs[0])] * len(datasets)
    if len(configs) != len(datasets):
        raise SystemExit(
            "Provide zero, one, or matching-count --dataset_config values."
        )
    return [_normalize_config(cfg) for cfg in configs]


def _sample_dataset_rows(
    dataset, target: int, seed: int
) -> List[Dict[str, object]]:
    """Shuffle (when supported) and take up to ``target`` rows."""
    if target <= 0:
        return []
    try:
        dataset = dataset.shuffle(seed=seed)
    except Exception:
        pass
    if hasattr(dataset, "__len__"):
        limit = min(target, len(dataset))
        dataset = dataset.select(range(limit))
        return [row for row in dataset]
    # Streaming datasets have no length; iterate until enough rows are taken.
    rows = []
    for row in dataset:
        rows.append(row)
        if len(rows) >= target:
            break
    return rows


def load_texts(args: argparse.Namespace) -> List[str]:
    """Collect raw texts from --text, --text_file, and --dataset arguments."""
    texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            texts.extend([line.strip() for line in handle if line.strip()])
    if args.text:
        texts.extend([t for t in args.text if t])
    if args.dataset:
        if load_dataset is None:
            raise SystemExit("datasets is required for --dataset")
        datasets = list(args.dataset)
        configs = expand_dataset_configs(datasets, list(args.dataset_config))
        # Split the sample budget as evenly as possible across datasets.
        num_datasets = len(datasets)
        base = args.num_samples // num_datasets
        remainder = args.num_samples % num_datasets
        for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
            target = base + (1 if idx < remainder else 0)
            dataset = load_dataset(
                dataset_name,
                config,
                split=args.dataset_split,
                trust_remote_code=True,
            )
            rows = _sample_dataset_rows(dataset, target, args.seed + idx)
            text_field = args.dataset_text_field or guess_text_field(dataset)
            for row in rows:
                value = row.get(text_field, None) if isinstance(row, dict) else None
                if isinstance(value, str) and value.strip():
                    texts.append(value)
    return texts


def load_texts_from_datasets(
    datasets: List[str],
    configs: List[Optional[str]],
    split: str,
    text_field: Optional[str],
    num_samples: int,
    seed: int,
) -> List[str]:
    """Collect texts from Hugging Face datasets, sharing the sample budget."""
    if not datasets:
        return []
    if load_dataset is None:
        raise SystemExit("datasets is required for --dataset")
    texts: List[str] = []
    num_datasets = len(datasets)
    base = num_samples // num_datasets
    remainder = num_samples % num_datasets
    for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
        target = base + (1 if idx < remainder else 0)
        dataset = load_dataset(
            dataset_name,
            config,
            split=split,
            trust_remote_code=True,
        )
        rows = _sample_dataset_rows(dataset, target, seed + idx)
        field = text_field or guess_text_field(dataset)
        for row in rows:
            value = row.get(field, None) if isinstance(row, dict) else None
            if isinstance(value, str) and value.strip():
                texts.append(value)
    return texts
def format_alpaca_example(instruction: str, inp: str, output: str) -> str:
    """Render an Alpaca-style record as plain prompt/response text."""
    if inp:
        return (
            "### Instruction:\n"
            f"{instruction}\n\n"
            "### Input:\n"
            f"{inp}\n\n"
            "### Response:\n"
            f"{output}"
        )
    return (
        "### Instruction:\n"
        f"{instruction}\n\n"
        "### Response:\n"
        f"{output}"
    )


def build_alpaca_messages(
    instruction: str, inp: str, output: str
) -> List[Dict[str, str]]:
    """Render an Alpaca-style record as user/assistant chat messages."""
    if inp:
        user_content = f"{instruction}\n\nInput:\n{inp}"
    else:
        user_content = instruction
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": output},
    ]


class FixedSeqDataset(torch.utils.data.Dataset):
    """Tokenize records to a fixed sequence length with right padding."""

    def __init__(
        self, records: List[Dict[str, object]], tokenizer, seq_len: int
    ) -> None:
        self.records = records
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.pad_id = tokenizer.pad_token_id
        if self.pad_id is None:
            self.pad_id = tokenizer.eos_token_id or 0

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        record = self.records[idx]
        chat_template = getattr(self.tokenizer, "chat_template", None)
        if (
            "messages" in record
            and hasattr(self.tokenizer, "apply_chat_template")
            and chat_template
        ):
            ids = self.tokenizer.apply_chat_template(
                record["messages"],
                tokenize=True,
                add_generation_prompt=False,
            )
        else:
            text = record.get("text", "")
            ids = self.tokenizer.encode(text, add_special_tokens=False)
        # Transformers may return a BatchEncoding here instead of a plain list.
        if hasattr(ids, "input_ids"):
            ids = ids.input_ids
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        elif not isinstance(ids, list):
            ids = list(ids)
        # Truncate to seq_len, then right-pad with the pad id as needed.
        if len(ids) > self.seq_len:
            ids = ids[: self.seq_len]
        attn = [1] * len(ids)
        if len(ids) < self.seq_len:
            pad_len = self.seq_len - len(ids)
            ids = ids + [self.pad_id] * pad_len
            attn = attn + [0] * pad_len
        return (
            torch.tensor(ids, dtype=torch.long),
            torch.tensor(attn, dtype=torch.long),
        )


def load_instruction_records(
    args: argparse.Namespace, num_samples: int
) -> List[Dict[str, object]]:
    """Load Alpaca-style instruction records as chat/text pairs."""
    if not args.instruction_dataset:
        return []
    if load_dataset is None:
        raise SystemExit("datasets is required for instruction dataset")
    dataset = load_dataset(
        args.instruction_dataset,
        _normalize_config(args.instruction_config),
        split=args.instruction_split,
        trust_remote_code=True,
    )
    if num_samples > 0:
        rows = _sample_dataset_rows(dataset, num_samples, args.seed)
    else:
        rows = dataset
    records: List[Dict[str, object]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        instruction = str(row.get(args.instruction_field_instruction, "")).strip()
        inp = str(row.get(args.instruction_field_input, "")).strip()
        output = str(row.get(args.instruction_field_output, "")).strip()
        if not instruction or not output:
            continue
        records.append(
            {
                "messages": build_alpaca_messages(instruction, inp, output),
                "text": format_alpaca_example(instruction, inp, output),
            }
        )
    return records


def build_token_chunks(
    texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
    """Concatenate tokenized texts and slice them into fixed-length chunks."""
    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    limit = None if num_samples <= 0 else num_samples
    for text in texts:
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        buffer.extend(ids)
        while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
            chunk = buffer[:seq_len]
            buffer = buffer[seq_len:]
            chunks.append(torch.tensor(chunk, dtype=torch.long))
        if limit is not None and len(chunks) >= limit:
            break
    return chunks
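

# A minimal, self-contained smoke test for FixedSeqDataset (a sketch, not part
# of the fuse_layers API). The toy tokenizer below is a stand-in for a real
# Hugging Face tokenizer so the truncation/padding logic can be exercised
# without downloading a model or dataset.
if __name__ == "__main__":
    class _ToyTokenizer:
        pad_token_id = 0
        eos_token_id = 0
        chat_template = None

        def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
            # Map each character to a small positive integer id.
            return [ord(ch) % 100 + 1 for ch in text]

    demo = FixedSeqDataset(
        [{"text": "hello world"}, {"text": "hi"}],
        tokenizer=_ToyTokenizer(),
        seq_len=8,
    )
    for i in range(len(demo)):
        ids, attn = demo[i]
        # First record is truncated to 8 tokens; second is right-padded with 0s.
        print(ids.tolist(), attn.tolist())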