Spaces:
Sleeping
Sleeping
| import json | |
| from typing import Dict, Optional | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| from utils.encoder_utils import build_self_attn_cond_masks | |
| from utils.logging_utils import log_for_0 | |
| def _process_count() -> int: | |
| try: | |
| import torch.distributed as dist | |
| if dist.is_available() and dist.is_initialized(): | |
| return dist.get_world_size() | |
| except Exception: | |
| pass | |
| return 1 | |
| def _process_index() -> int: | |
| try: | |
| import torch.distributed as dist | |
| if dist.is_available() and dist.is_initialized(): | |
| return dist.get_rank() | |
| except Exception: | |
| pass | |
| return 0 | |
| def get_pad_token_id(tokenizer, pad_token: str = "pad") -> int: | |
| """Resolve the token id used for padding, optionally using EOS as pad.""" | |
| token_id = tokenizer.eos_token_id if pad_token == "eos" else tokenizer.pad_token_id | |
| if token_id is None: | |
| raise ValueError("Tokenizer has no pad_token_id or eos_token_id.") | |
| return token_id | |
| def prepare_batch(batch: Dict, config, generator: torch.Generator) -> Dict: | |
| """Convert numpy batch to torch tensors and sample label-drop decisions.""" | |
| result = {} | |
| for k, v in batch.items(): | |
| if isinstance(v, np.ndarray): | |
| result[k] = torch.from_numpy(v) | |
| elif isinstance(v, torch.Tensor): | |
| result[k] = v | |
| else: | |
| result[k] = v | |
| batch_size = result["input_ids"].shape[0] | |
| label_drop_mask = torch.zeros((batch_size,), dtype=torch.bool) | |
| if config.label_drop_prob > 0: | |
| u = torch.rand((batch_size,), generator=generator) | |
| label_drop_mask = u < config.label_drop_prob | |
| result["label_drop_mask"] = label_drop_mask | |
| return result | |
| def pad_and_truncate(ids_list, target_len, pad_token_id): | |
| """Pad or truncate sequences to target_len, return stacked array and lengths.""" | |
| padded, lengths = [], [] | |
| for ids in ids_list: | |
| orig_len = min(len(ids), target_len) | |
| ids = ids[:target_len] | |
| if orig_len < target_len: | |
| ids = np.concatenate([ids, np.full(target_len - orig_len, pad_token_id, dtype=ids.dtype)]) | |
| padded.append(ids) | |
| lengths.append(orig_len) | |
| return np.stack(padded), np.array(lengths) | |
| def get_dataloader( | |
| dataset, | |
| batch_size: int, | |
| shuffle: bool = True, | |
| num_workers: int = 0, | |
| drop_last: bool = True, | |
| max_seq_length: int = 512, | |
| pad_token_id: int = 0, | |
| max_input_seq_length: Optional[int] = None, | |
| distributed: bool = True, | |
| ): | |
| """Create a DataLoader.""" | |
| def collate_fn(batch_list): | |
| input_ids_list = [np.array(item["input_ids"]) for item in batch_list] | |
| if "condition_input_ids" in batch_list[0]: | |
| seq_list, cond_lens = [], [] | |
| for item in batch_list: | |
| cond = np.array(item["condition_input_ids"])[:max_input_seq_length] | |
| inp = np.array(item["input_ids"]) | |
| seq_list.append(np.concatenate([cond, inp])) | |
| cond_lens.append(len(cond)) | |
| cond_lens = np.array(cond_lens) | |
| else: | |
| seq_list = input_ids_list | |
| cond_lens = np.zeros(len(input_ids_list), dtype=np.int32) | |
| ids, total_lens = pad_and_truncate(seq_list, max_seq_length, pad_token_id) | |
| pos = np.arange(max_seq_length)[None, :] | |
| is_cond = pos < cond_lens[:, None] | |
| is_valid = pos < total_lens[:, None] | |
| encoder_attn, attn, pred = build_self_attn_cond_masks(is_cond, is_valid, xp=np) | |
| result = { | |
| "input_ids": ids, | |
| "encoder_attention_mask": encoder_attn, | |
| "attention_mask": attn, | |
| "cond_seq_mask": pred, | |
| } | |
| for key in ("index", "input", "target"): | |
| if key in batch_list[0]: | |
| result[key] = [item[key] for item in batch_list] | |
| return result | |
| common = dict( | |
| batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, | |
| drop_last=drop_last, persistent_workers=num_workers > 0, | |
| pin_memory=True, | |
| ) | |
| if distributed: | |
| sampler = DistributedSampler( | |
| dataset, num_replicas=_process_count(), rank=_process_index(), | |
| shuffle=shuffle, drop_last=drop_last, | |
| ) | |
| return DataLoader(dataset, sampler=sampler, **common) | |
| return DataLoader(dataset, shuffle=shuffle, **common) | |
| def load_jsonl_dataset(path, tokenizer, input_key="input", output_key="output"): | |
| """Load a JSONL eval set (one `{input, output}` example per line).""" | |
| examples = [] | |
| with open(path, "r", encoding="utf-8") as f: | |
| for i, line in enumerate(f): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| data = json.loads(line) | |
| examples.append({ | |
| "index": i, | |
| "input": data[input_key], | |
| "target": data[output_key], | |
| "condition_input_ids": tokenizer(data[input_key], add_special_tokens=False)["input_ids"], | |
| "input_ids": tokenizer(data[output_key], add_special_tokens=False)["input_ids"], | |
| }) | |
| return examples | |
| # ============================================ | |
| # Dataset loading | |
| # ============================================ | |
| def _looks_like_save_to_disk_arrow(ds) -> bool: | |
| """Detect HF datasets uploaded via `save_to_disk` (returns 1-row of metadata).""" | |
| return ( | |
| len(ds) == 1 | |
| and any(c.startswith("_") for c in ds.column_names) | |
| and not any(not c.startswith("_") for c in ds.column_names) | |
| ) | |
| def load_dataset_split(path: str, dataset_cache_dir=None): | |
| """Load a dataset. Tries HuggingFace Hub first; falls back to local on-disk Arrow.""" | |
| from datasets import DatasetDict, load_dataset as hf_load_dataset, load_from_disk | |
| ds = None | |
| try: | |
| ds = hf_load_dataset(path, cache_dir=dataset_cache_dir) | |
| except Exception: | |
| ds = load_from_disk(path) | |
| if isinstance(ds, DatasetDict): | |
| splits = list(ds.keys()) | |
| if len(splits) != 1: | |
| raise ValueError(f"Expected dataset at {path!r} to have a single split, got {splits}.") | |
| ds = ds[splits[0]] | |
| if _looks_like_save_to_disk_arrow(ds): | |
| from huggingface_hub import snapshot_download | |
| log_for_0( | |
| f"Dataset at {path!r} looks like a save_to_disk-format HF repo; " | |
| f"re-downloading via snapshot_download + load_from_disk." | |
| ) | |
| local_dir = snapshot_download(repo_id=path, repo_type="dataset", cache_dir=dataset_cache_dir) | |
| ds = load_from_disk(local_dir) | |
| if isinstance(ds, DatasetDict): | |
| splits = list(ds.keys()) | |
| if len(splits) != 1: | |
| raise ValueError(f"Expected dataset at {path!r} to have a single split, got {splits}.") | |
| ds = ds[splits[0]] | |
| ds.set_format(type="numpy", columns=ds.column_names) | |
| return ds | |
| def load_dataset(config, dataset_cache_dir=None): | |
| """Resolve config.data_path / config.eval_data_path into train/eval datasets.""" | |
| log_for_0(f"Loading dataset from {config.data_path}...") | |
| train_dataset = load_dataset_split(config.data_path, dataset_cache_dir) | |
| log_for_0(f"Train size: {len(train_dataset)}") | |
| eval_dataset = None | |
| if config.eval_data_path: | |
| eval_dataset = load_dataset_split(config.eval_data_path, dataset_cache_dir) | |
| log_for_0(f"Eval size: {len(eval_dataset)}") | |
| else: | |
| log_for_0("No eval dataset") | |
| return train_dataset, eval_dataset | |