#!/usr/bin/env python3
"""Dataset and text helpers for fuse_layers."""
import argparse
from typing import Dict, List, Optional
import torch
import torch.utils.data  # explicit submodule import for torch.utils.data.Dataset
try:
from datasets import load_dataset
except Exception: # pragma: no cover - optional dependency
load_dataset = None
def guess_text_field(dataset) -> str:
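    """Pick the most likely text column from a datasets object.

    Prefers a column named "text"; otherwise falls back to the first
    column, and finally to "text" when no schema can be inspected.
    """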
if hasattr(dataset, "column_names") and dataset.column_names:
if "text" in dataset.column_names:
return "text"
return dataset.column_names[0]
if hasattr(dataset, "features"):
names = list(dataset.features.keys())
if "text" in names:
return "text"
if names:
return names[0]
return "text"
def _normalize_config(config: Optional[str]) -> Optional[str]:
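    """Map sentinel strings ("none", "null", "-", case-insensitive) to None."""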
if config is None:
return None
if config.strip().lower() in {"none", "null", "-"}:
return None
return config
def expand_dataset_configs(
datasets: List[str], configs: List[str]
) -> List[Optional[str]]:
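    """Align --dataset_config values with --dataset entries.

    Accepts zero configs (all None), a single config broadcast across all
    datasets, or exactly one config per dataset.
    """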
if not configs:
return [None] * len(datasets)
if len(configs) == 1 and len(datasets) > 1:
return [_normalize_config(configs[0])] * len(datasets)
if len(configs) != len(datasets):
raise SystemExit(
"Provide zero, one, or matching-count --dataset_config values."
)
return [_normalize_config(cfg) for cfg in configs]
def _sample_dataset_rows(
dataset, target: int, seed: int
) -> List[Dict[str, object]]:
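    """Return up to ``target`` rows, shuffling first when the dataset supports it.

    Map-style datasets are truncated with ``select``; iterable datasets are
    consumed row by row until ``target`` rows have been collected.
    """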
if target <= 0:
return []
try:
dataset = dataset.shuffle(seed=seed)
except Exception:
pass
if hasattr(dataset, "__len__"):
limit = min(target, len(dataset))
dataset = dataset.select(range(limit))
return [row for row in dataset]
rows = []
for row in dataset:
rows.append(row)
if len(rows) >= target:
break
return rows
def load_texts(args: argparse.Namespace) -> List[str]:
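    """Gather raw texts from ``args.text_file``, ``args.text``, and ``args.dataset``."""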
texts: List[str] = []
if args.text_file:
with open(args.text_file, "r", encoding="utf-8") as handle:
texts.extend([line.strip() for line in handle if line.strip()])
if args.text:
texts.extend([t for t in args.text if t])
    if args.dataset:
        datasets = list(args.dataset)
        configs = expand_dataset_configs(datasets, list(args.dataset_config))
        # Reuse the shared loader so CLI-driven and programmatic dataset
        # sampling follow the same splitting and field-guessing logic.
        texts.extend(
            load_texts_from_datasets(
                datasets,
                configs,
                args.dataset_split,
                args.dataset_text_field,
                args.num_samples,
                args.seed,
            )
        )
return texts
def load_texts_from_datasets(
datasets: List[str],
configs: List[Optional[str]],
split: str,
text_field: Optional[str],
num_samples: int,
seed: int,
) -> List[str]:
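    """Spread ``num_samples`` across the given datasets and collect non-empty strings.

    The per-dataset text column is either ``text_field`` or guessed with
    ``guess_text_field``.
    """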
if not datasets:
return []
if load_dataset is None:
raise SystemExit("datasets is required for --dataset")
texts: List[str] = []
num_datasets = len(datasets)
base = num_samples // num_datasets
remainder = num_samples % num_datasets
for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
target = base + (1 if idx < remainder else 0)
dataset = load_dataset(
dataset_name,
config,
split=split,
trust_remote_code=True,
)
rows = _sample_dataset_rows(dataset, target, seed + idx)
field = text_field or guess_text_field(dataset)
for row in rows:
value = row.get(field, None) if isinstance(row, dict) else None
if isinstance(value, str) and value.strip():
texts.append(value)
return texts
def format_alpaca_example(instruction: str, inp: str, output: str) -> str:
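    """Render an instruction/input/output triple as an Alpaca-style prompt string."""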
if inp:
return (
"### Instruction:\n"
f"{instruction}\n\n"
"### Input:\n"
f"{inp}\n\n"
"### Response:\n"
f"{output}"
)
return (
"### Instruction:\n"
f"{instruction}\n\n"
"### Response:\n"
f"{output}"
)
def build_alpaca_messages(
instruction: str, inp: str, output: str
) -> List[Dict[str, str]]:
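    """Convert an Alpaca-style example into a user/assistant chat-message pair."""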
if inp:
user_content = f"{instruction}\n\nInput:\n{inp}"
else:
user_content = instruction
return [
{"role": "user", "content": user_content},
{"role": "assistant", "content": output},
]
class FixedSeqDataset(torch.utils.data.Dataset):
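    """Dataset of token id / attention mask pairs padded or truncated to ``seq_len``.

    Records with ``messages`` are rendered through the tokenizer's chat
    template when one is available; otherwise the plain ``text`` field is
    tokenized.
    """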
def __init__(self, records: List[Dict[str, object]], tokenizer, seq_len: int) -> None:
self.records = records
self.tokenizer = tokenizer
self.seq_len = seq_len
self.pad_id = tokenizer.pad_token_id
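        # Fall back to the EOS token (or 0) when no pad token is defined.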
if self.pad_id is None:
self.pad_id = tokenizer.eos_token_id or 0
def __len__(self) -> int:
return len(self.records)
def __getitem__(self, idx: int):
record = self.records[idx]
chat_template = getattr(self.tokenizer, "chat_template", None)
if (
"messages" in record
and hasattr(self.tokenizer, "apply_chat_template")
and chat_template
):
ids = self.tokenizer.apply_chat_template(
record["messages"],
tokenize=True,
add_generation_prompt=False,
)
else:
text = record.get("text", "")
ids = self.tokenizer.encode(text, add_special_tokens=False)
# Transformers may return a BatchEncoding here instead of a plain list.
if hasattr(ids, "input_ids"):
ids = ids.input_ids
if isinstance(ids, torch.Tensor):
ids = ids.tolist()
elif not isinstance(ids, list):
ids = list(ids)
if len(ids) > self.seq_len:
ids = ids[: self.seq_len]
attn = [1] * len(ids)
if len(ids) < self.seq_len:
pad_len = self.seq_len - len(ids)
ids = ids + [self.pad_id] * pad_len
attn = attn + [0] * pad_len
return (
torch.tensor(ids, dtype=torch.long),
torch.tensor(attn, dtype=torch.long),
)
def load_instruction_records(
args: argparse.Namespace, num_samples: int
) -> List[Dict[str, object]]:
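    """Load ``args.instruction_dataset`` and build Alpaca-formatted records.

    Each record carries both chat ``messages`` and a flat ``text`` prompt;
    rows missing an instruction or output are skipped.
    """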
if not args.instruction_dataset:
return []
if load_dataset is None:
raise SystemExit("datasets is required for instruction dataset")
dataset = load_dataset(
args.instruction_dataset,
_normalize_config(args.instruction_config),
split=args.instruction_split,
trust_remote_code=True,
)
if num_samples > 0:
rows = _sample_dataset_rows(dataset, num_samples, args.seed)
else:
rows = dataset
records: List[Dict[str, object]] = []
for row in rows:
if not isinstance(row, dict):
continue
        # Coerce missing or None fields to empty strings so they fail the checks below.
        instruction = str(row.get(args.instruction_field_instruction) or "").strip()
        inp = str(row.get(args.instruction_field_input) or "").strip()
        output = str(row.get(args.instruction_field_output) or "").strip()
if not instruction or not output:
continue
records.append(
{
"messages": build_alpaca_messages(instruction, inp, output),
"text": format_alpaca_example(instruction, inp, output),
}
)
return records
def build_token_chunks(
texts: List[str], tokenizer, seq_len: int, num_samples: int
) -> List[torch.Tensor]:
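    """Pack tokenized texts into fixed-length chunks of ``seq_len`` token ids.

    Texts are concatenated into a rolling buffer and sliced into chunks;
    packing stops after ``num_samples`` chunks unless ``num_samples`` <= 0,
    and a trailing partial chunk is discarded.
    """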
chunks: List[torch.Tensor] = []
buffer: List[int] = []
limit = None if num_samples <= 0 else num_samples
for text in texts:
ids = tokenizer.encode(text, add_special_tokens=False)
if not ids:
continue
buffer.extend(ids)
while len(buffer) >= seq_len and (limit is None or len(chunks) < limit):
chunk = buffer[:seq_len]
buffer = buffer[seq_len:]
chunks.append(torch.tensor(chunk, dtype=torch.long))
if limit is not None and len(chunks) >= limit:
break
return chunks