| |
| """Dataset and text helpers for fuse_layers.""" |
|
|
| import argparse |
| from typing import Dict, List, Optional |
|
|
| import torch |
|
|
| try: |
| from datasets import load_dataset |
| except Exception: |
| load_dataset = None |
|
|
|
|
def guess_text_field(dataset) -> str:
    """Pick the most plausible text column name for *dataset*.

    Prefers a column literally named ``"text"``; otherwise falls back to
    the first available column, and finally to the literal ``"text"``.
    """
    columns = getattr(dataset, "column_names", None)
    if columns:
        return "text" if "text" in columns else columns[0]
    if hasattr(dataset, "features"):
        names = list(dataset.features.keys())
        if "text" in names:
            return "text"
        if names:
            return names[0]
    return "text"
|
|
|
|
| def _normalize_config(config: Optional[str]) -> Optional[str]: |
| if config is None: |
| return None |
| if config.strip().lower() in {"none", "null", "-"}: |
| return None |
| return config |
|
|
|
|
def expand_dataset_configs(
    datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
    """Align the --dataset_config list with the --dataset list.

    Zero configs means "no config" for every dataset; a single config is
    broadcast across all datasets; otherwise the counts must match.
    Exits with an error message on a count mismatch.
    """
    count = len(datasets)
    if not configs:
        return [None] * count
    if len(configs) == 1 and count > 1:
        return [_normalize_config(configs[0])] * count
    if len(configs) == count:
        return [_normalize_config(cfg) for cfg in configs]
    raise SystemExit(
        "Provide zero, one, or matching-count --dataset_config values."
    )
|
|
|
|
| def _sample_dataset_rows( |
| dataset, target: int, seed: int |
| ) -> List[Dict[str, object]]: |
| if target <= 0: |
| return [] |
| try: |
| dataset = dataset.shuffle(seed=seed) |
| except Exception: |
| pass |
|
|
| if hasattr(dataset, "__len__"): |
| limit = min(target, len(dataset)) |
| dataset = dataset.select(range(limit)) |
| return [row for row in dataset] |
|
|
| rows = [] |
| for row in dataset: |
| rows.append(row) |
| if len(rows) >= target: |
| break |
| return rows |
|
|
|
|
def load_texts(args: argparse.Namespace) -> List[str]:
    """Collect calibration texts from --text_file, --text, and --dataset.

    File lines and explicit --text strings are taken verbatim (blank
    entries are dropped).  Dataset sampling is delegated to
    ``load_texts_from_datasets`` so both entry points share a single
    code path (the original duplicated that loop here).

    Raises SystemExit (via the helpers) when --dataset is requested but
    the ``datasets`` package is unavailable, or when the config count
    does not match the dataset count.
    """
    texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            texts.extend(line.strip() for line in handle if line.strip())
    if args.text:
        texts.extend(t for t in args.text if t)

    if args.dataset:
        datasets = list(args.dataset)
        configs = expand_dataset_configs(datasets, list(args.dataset_config))
        texts.extend(
            load_texts_from_datasets(
                datasets,
                configs,
                args.dataset_split,
                args.dataset_text_field,
                args.num_samples,
                args.seed,
            )
        )

    return texts
|
|
|
|
def load_texts_from_datasets(
    datasets: List[str],
    configs: List[Optional[str]],
    split: str,
    text_field: Optional[str],
    num_samples: int,
    seed: int,
) -> List[str]:
    """Sample roughly *num_samples* non-empty strings across *datasets*.

    The sample budget is split evenly over the datasets, with the
    remainder going to the earliest ones.  *configs* must already be
    normalized and aligned with *datasets*.  Exits when the ``datasets``
    package is missing.
    """
    if not datasets:
        return []
    if load_dataset is None:
        raise SystemExit("datasets is required for --dataset")

    collected: List[str] = []
    per_dataset, extra = divmod(num_samples, len(datasets))

    for idx, (name, config) in enumerate(zip(datasets, configs)):
        quota = per_dataset + (1 if idx < extra else 0)
        loaded = load_dataset(
            name,
            config,
            split=split,
            trust_remote_code=True,
        )
        field = text_field or guess_text_field(loaded)
        for row in _sample_dataset_rows(loaded, quota, seed + idx):
            value = row.get(field, None) if isinstance(row, dict) else None
            if isinstance(value, str) and value.strip():
                collected.append(value)
    return collected
|
|
|
|
def format_alpaca_example(instruction: str, inp: str, output: str) -> str:
    """Render one Alpaca-style record as a flat prompt string.

    The "### Input:" section is included only when *inp* is non-empty;
    sections are separated by blank lines.
    """
    sections = [f"### Instruction:\n{instruction}"]
    if inp:
        sections.append(f"### Input:\n{inp}")
    sections.append(f"### Response:\n{output}")
    return "\n\n".join(sections)
|
|
|
|
def build_alpaca_messages(
    instruction: str, inp: str, output: str
) -> List[Dict[str, str]]:
    """Convert one Alpaca record into a two-turn chat transcript.

    A non-empty *inp* is appended to the user turn under an "Input:"
    heading; the assistant turn carries the reference *output*.
    """
    user = f"{instruction}\n\nInput:\n{inp}" if inp else instruction
    return [
        {"role": "user", "content": user},
        {"role": "assistant", "content": output},
    ]
|
|
|
|
class FixedSeqDataset(torch.utils.data.Dataset):
    """Tokenizes records to a fixed length with right padding.

    Each record supplies either a ``messages`` list (rendered through the
    tokenizer's chat template when one is configured) or a raw ``text``
    string.  Items are ``(input_ids, attention_mask)`` LongTensor pairs
    of exactly ``seq_len`` tokens.
    """

    def __init__(self, records: List[Dict[str, object]], tokenizer, seq_len: int) -> None:
        self.records = records
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # Fall back to EOS (or 0) when the tokenizer has no pad token.
        pad = tokenizer.pad_token_id
        self.pad_id = pad if pad is not None else (tokenizer.eos_token_id or 0)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        record = self.records[idx]
        template = getattr(self.tokenizer, "chat_template", None)
        use_chat = (
            "messages" in record
            and template
            and hasattr(self.tokenizer, "apply_chat_template")
        )
        if use_chat:
            ids = self.tokenizer.apply_chat_template(
                record["messages"],
                tokenize=True,
                add_generation_prompt=False,
            )
        else:
            ids = self.tokenizer.encode(
                record.get("text", ""), add_special_tokens=False
            )

        # Normalize the tokenizer output to a plain list of ints
        # (BatchEncoding, tensor, or any other iterable).
        if hasattr(ids, "input_ids"):
            ids = ids.input_ids
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        elif not isinstance(ids, list):
            ids = list(ids)

        # Truncate, then right-pad up to exactly seq_len tokens.
        ids = ids[: self.seq_len]
        pad_len = self.seq_len - len(ids)
        mask = [1] * len(ids) + [0] * pad_len
        ids = ids + [self.pad_id] * pad_len

        return (
            torch.tensor(ids, dtype=torch.long),
            torch.tensor(mask, dtype=torch.long),
        )
|
|
|
|
def load_instruction_records(
    args: argparse.Namespace, num_samples: int
) -> List[Dict[str, object]]:
    """Load Alpaca-style instruction rows and pre-render both formats.

    Each returned record carries a chat ``messages`` list and a flat
    ``text`` prompt so downstream code can pick either representation.
    Rows missing an instruction or an output are dropped.  A
    non-positive *num_samples* keeps every row.  Exits when the
    ``datasets`` package is unavailable.
    """
    if not args.instruction_dataset:
        return []
    if load_dataset is None:
        raise SystemExit("datasets is required for instruction dataset")

    dataset = load_dataset(
        args.instruction_dataset,
        _normalize_config(args.instruction_config),
        split=args.instruction_split,
        trust_remote_code=True,
    )
    rows = (
        _sample_dataset_rows(dataset, num_samples, args.seed)
        if num_samples > 0
        else dataset
    )

    records: List[Dict[str, object]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        instruction = str(row.get(args.instruction_field_instruction, "")).strip()
        inp = str(row.get(args.instruction_field_input, "")).strip()
        output = str(row.get(args.instruction_field_output, "")).strip()
        if instruction and output:
            records.append(
                {
                    "messages": build_alpaca_messages(instruction, inp, output),
                    "text": format_alpaca_example(instruction, inp, output),
                }
            )
    return records
|
|
|
|
def build_token_chunks(
    texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
    """Pack tokenized *texts* into fixed-size LongTensor chunks.

    Token streams are concatenated across texts and cut into consecutive
    ``seq_len`` windows; a trailing partial window is discarded.  A
    non-positive *num_samples* means "no limit".
    """
    limit = num_samples if num_samples > 0 else None
    chunks: List[torch.Tensor] = []
    pending: List[int] = []
    for text in texts:
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if not token_ids:
            continue
        pending.extend(token_ids)
        while len(pending) >= seq_len:
            if limit is not None and len(chunks) >= limit:
                return chunks
            chunks.append(torch.tensor(pending[:seq_len], dtype=torch.long))
            pending = pending[seq_len:]
    return chunks
|
|