import json
from pathlib import Path
from typing import Callable, Optional

import torch
from megatron.core import parallel_state
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
    MegatronPretrainingRandomSampler,
    MegatronPretrainingSampler,
)
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
    MegatronPretrainingBatchSampler,
    MegatronPretrainingRandomBatchSampler,
)
from nemo.core.classes import Dataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero
from omegaconf import DictConfig
from torch.utils.data import DataLoader

def build_dataloader(
    dataset: Dataset,
    consumed_samples: int,
    micro_batch_size: int,
    global_batch_size: int,
    collate_fn: Optional[Callable] = None,
    seed: Optional[int] = None,
) -> DataLoader:
    common_params: dict = {
        "total_samples": len(dataset),
        "consumed_samples": consumed_samples,
        "micro_batch_size": micro_batch_size,
        "global_batch_size": global_batch_size,
        "data_parallel_rank": parallel_state.get_data_parallel_rank(),
        "data_parallel_size": parallel_state.get_data_parallel_world_size(),
        "drop_last": True,
        "pad_samples_to_global_batch_size": False,
    }

    # A non-negative seed enables the shuffling batch sampler; otherwise batches
    # are drawn in dataset order.
    if seed is not None and seed >= 0:
        batch_sampler = MegatronPretrainingRandomBatchSampler(**common_params, seed=seed)
    else:
        batch_sampler = MegatronPretrainingBatchSampler(**common_params)

    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_fn,
    )

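# Example usage (a minimal sketch, not called in this module): build_dataloader assumes
# megatron.core.parallel_state has already been initialized by the surrounding
# NeMo/Megatron setup. `train_dataset` and the batch sizes below are illustrative only.
#
#     train_dataloader = build_dataloader(
#         dataset=train_dataset,
#         consumed_samples=0,
#         micro_batch_size=1,
#         global_batch_size=8,
#         collate_fn=train_dataset.collate_fn,
#         seed=42,  # non-negative seed -> shuffled batches
#     )
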
def custom_build_dataloader(
    dataset: Dataset,
    consumed_samples: int,
    mbs: int,
    gbs: int,
    num_workers: int = 0,
    drop_last: bool = True,
    pad_samples_to_global_batch_size: bool = False,
    load_gbs: bool = True,
    seed: Optional[int] = None,
    use_random_sampler: bool = True,
    collate_fn=None,
):
    common_params = {
        "total_samples": len(dataset),
        "consumed_samples": consumed_samples,
        "micro_batch_size": mbs,
        "data_parallel_rank": parallel_state.get_data_parallel_rank(),
        "data_parallel_size": parallel_state.get_data_parallel_world_size(),
        "drop_last": drop_last,
        "global_batch_size": gbs,
        "pad_samples_to_global_batch_size": pad_samples_to_global_batch_size,
    }

    # load_gbs selects the global-batch samplers (MegatronPretraining*BatchSampler);
    # otherwise the per-micro-batch samplers are used. The seed is only passed to
    # the random (shuffling) samplers and is ignored otherwise.
    if use_random_sampler:
        cls = (
            MegatronPretrainingRandomBatchSampler
            if load_gbs
            else MegatronPretrainingRandomSampler
        )
        common_params["seed"] = seed
    else:
        cls = (
            MegatronPretrainingBatchSampler if load_gbs else MegatronPretrainingSampler
        )
    batch_sampler = cls(**common_params)

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

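# Example usage (a sketch; names are illustrative): a non-shuffled loader built with
# the same helper, e.g. for evaluation on `dev_dataset`.
#
#     dev_dataloader = custom_build_dataloader(
#         dataset=dev_dataset,
#         consumed_samples=0,
#         mbs=1,
#         gbs=8,
#         num_workers=2,
#         use_random_sampler=False,  # sequential MegatronPretrainingBatchSampler
#         collate_fn=dev_dataset.collate_fn,
#     )
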
def load_datasets(cfg: DictConfig) -> tuple[list[dict], list[dict]]:
    data_name2num_examples: dict[str, dict[str, int]] = {}
    total_train_examples: list[dict] = []
    total_dev_examples: list[dict] = []
    for data_name, data_info in cfg.datasets.items():
        dataset_path: Path = Path(f"{cfg.data_dir}/{data_name}.jsonl")
        if not dataset_path.exists():
            raise FileNotFoundError(f"{dataset_path} does not exist.")
        if data_info.max_train_samples == 0:
            if is_global_rank_zero():
                logging.info(
                    f"max_train_samples for {data_name} is set to 0. Skipping this dataset."
                )
            continue

        if is_global_rank_zero():
            logging.info(f"processing {dataset_path}...")
        loaded_examples: list[dict] = []
        with dataset_path.open(encoding="utf-8") as f:
            for line in f:
                loaded_examples.append(json.loads(line))

        if data_info.max_train_samples > len(loaded_examples) and is_global_rank_zero():
            logging.warning(
                f"{data_name} has only {len(loaded_examples)} examples, "
                f"but max_train_samples is set to {data_info.max_train_samples}. "
                "Using all available examples."
            )

        # max_train_samples == -1 means "use every example in the file".
        max_train_samples: int = (
            data_info.max_train_samples
            if data_info.max_train_samples != -1
            else len(loaded_examples)
        )
        # The dev split is taken from the head of the file and capped both by an
        # absolute count and by a ratio of the dataset size.
        max_dev_samples: int = 0
        if data_info.split_dev:
            max_dev_samples = min(
                cfg.max_dev_samples,
                int(len(loaded_examples) * cfg.max_dev_ratio),
            )
        train_examples: list[dict] = (
            loaded_examples[max_dev_samples : max_dev_samples + max_train_samples]
            * data_info.upsampling_factor
        )
        dev_examples: list[dict] = (
            loaded_examples[:max_dev_samples] * data_info.upsampling_factor
        )

        total_train_examples.extend(train_examples)
        total_dev_examples.extend(dev_examples)
        data_name2num_examples[data_name] = {
            "train": len(train_examples),
            "dev": len(dev_examples),
            "original": len(loaded_examples),
            "upsampling_factor": data_info.upsampling_factor,
        }

    if is_global_rank_zero():
        num_total_original_examples: int = 0
        logging.info("------------------------------")
        logging.info("Dataset summary (original -> train/dev)")
        for data_name, num_examples in data_name2num_examples.items():
            num_total_original_examples += num_examples["original"]
            logging.info(
                f"{data_name}: {num_examples['original']} -> "
                f"{num_examples['train']}/{num_examples['dev']} "
                f"(upsampling factor: {num_examples['upsampling_factor']})"
            )
        logging.info(
            f"Total: {num_total_original_examples} -> "
            f"{len(total_train_examples)}/{len(total_dev_examples)}"
        )
        logging.info("------------------------------")

    return total_train_examples, total_dev_examples

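# Example (a sketch of the config shape that load_datasets reads; dataset names and
# numbers are illustrative, not part of this repository):
#
#     from omegaconf import OmegaConf
#
#     cfg = OmegaConf.create(
#         {
#             "data_dir": "data",
#             "max_dev_samples": 100,
#             "max_dev_ratio": 0.05,
#             "datasets": {
#                 "example_sft": {
#                     "max_train_samples": -1,  # -1 = use all examples
#                     "split_dev": True,
#                     "upsampling_factor": 1,
#                 },
#             },
#         }
#     )
#     train_examples, dev_examples = load_datasets(cfg)
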
class LLMJPSFTDataset(Dataset):
    """Packs multi-turn SFT conversations into fixed-length token sequences.

    Each example returned by __getitem__ is a dict with `input_ids` and `loss_mask`
    of length `max_seq_length + 1`, so that collate_fn can shift by one token to
    produce inputs and labels.
    """

    def __init__(
        self,
        loaded_examples: list[dict],
        tokenizer: TokenizerSpec,
        use_loss_mask: bool,
        max_seq_length: int = 4096,
    ):
        self.tokenizer = tokenizer
        self.use_loss_mask: bool = use_loss_mask
        self.max_seq_length: int = max_seq_length

        self.examples: list[dict[str, list[int]]] = self._process_examples(
            loaded_examples
        )

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict[str, list[int]]:
        return self.examples[idx]

    def _process_examples(
        self, loaded_examples: list[dict]
    ) -> list[dict[str, list[int]]]:
        all_input_ids: list[int] = []
        all_loss_mask: list[int] = []
        for example_idx, loaded_example in enumerate(loaded_examples):
            conversation: list[dict[str, str]] = loaded_example["messages"]
            assert len(conversation) >= 3
            assert conversation[0]["role"] == "system"

            # Start each conversation with BOS followed by the system prompt. When
            # use_loss_mask is enabled, no loss is computed on these tokens.
            input_ids: list[int] = [self.tokenizer.bos_id] + self.tokenizer.text_to_ids(
                conversation[0]["content"]
            )
            loss_mask: list[int] = (
                [0] * len(input_ids) if self.use_loss_mask else [1] * len(input_ids)
            )
            for turn_idx in range(1, len(conversation[1:]) // 2 + 1):
                user_message: dict[str, str] = conversation[2 * turn_idx - 1]
                assistant_message: dict[str, str] = conversation[2 * turn_idx]
                assert user_message["role"] == "user"
                assert assistant_message["role"] == "assistant"

                if self.use_loss_mask:
                    # Tokenize prompt and response separately so the loss mask can
                    # cover only the assistant response (plus EOS). The [1:]/[2:]
                    # slices drop tokenizer-dependent leading tokens introduced by
                    # tokenizing from the start of a string.
                    prompt_ids: list[int] = self.tokenizer.text_to_ids(
                        f"\n\n### 指示:\n{user_message['content']}\n\n### 応答:\n"
                    )[1:]
                    response_ids: list[int] = self.tokenizer.text_to_ids(
                        f"\n{assistant_message['content']}"
                    )[2:] + [self.tokenizer.eos_id]
                    input_ids.extend(prompt_ids + response_ids)
                    loss_mask.extend([0] * len(prompt_ids) + [1] * len(response_ids))
                else:
                    prompt_response_ids: list[int] = self.tokenizer.text_to_ids(
                        f"\n\n### 指示:\n{user_message['content']}\n\n### 応答:\n{assistant_message['content']}"
                    )[1:] + [self.tokenizer.eos_id]
                    input_ids.extend(prompt_response_ids)
                    loss_mask.extend([1] * len(prompt_response_ids))

            if is_global_rank_zero() and example_idx < 2:
                logging.info(f"{example_idx = }")
                logging.info(f"{input_ids = }")
                logging.info(f"{loss_mask = }")

            all_input_ids.extend(input_ids)
            all_loss_mask.extend(loss_mask)

        # Pack the concatenated token stream into chunks of max_seq_length + 1; the
        # extra token lets collate_fn shift inputs/labels by one. Chunks whose loss
        # mask is all zeros would contribute no loss and are skipped, as is a
        # trailing chunk shorter than max_seq_length + 1.
        examples: list[dict[str, list[int]]] = []
        for i in range(0, len(all_input_ids), self.max_seq_length + 1):
            chunked_input_ids: list[int] = all_input_ids[i : i + self.max_seq_length + 1]
            chunked_loss_mask: list[int] = all_loss_mask[i : i + self.max_seq_length + 1]
            if len(chunked_input_ids) == self.max_seq_length + 1:
                if set(chunked_loss_mask) == {0}:
                    continue
                examples.append(
                    {"input_ids": chunked_input_ids, "loss_mask": chunked_loss_mask}
                )
        return examples

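    # For reference, one conversation is rendered (before tokenization) roughly as
    #
    #     <BOS>{system}\n\n### 指示:\n{user}\n\n### 応答:\n{assistant}<EOS> ...
    #
    # with the instruction/response block repeated per turn; when use_loss_mask is
    # enabled, loss is computed only on the assistant responses and their EOS tokens.
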
    @torch.no_grad()
    def _create_attention_mask(self, seq_length: int) -> torch.Tensor:
        # Causal mask where True marks positions to be masked out (everything above
        # the diagonal).
        attention_mask = torch.tril(torch.ones((seq_length, seq_length))).unsqueeze(0)
        attention_mask = attention_mask < 0.5
        return attention_mask

    def collate_fn(self, batch: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
        # Shift by one token: positions 0..L-1 are the inputs, 1..L the labels.
        input_ids: list[list[int]] = [item["input_ids"][:-1] for item in batch]
        labels: list[list[int]] = [item["input_ids"][1:] for item in batch]
        loss_mask: list[list[int]] = [item["loss_mask"][1:] for item in batch]

        processed_batch = {
            "tokens": torch.LongTensor(input_ids),
            "position_ids": torch.LongTensor(
                [list(range(self.max_seq_length)) for _ in batch]
            ),
            "attention_mask": torch.stack(
                [self._create_attention_mask(self.max_seq_length) for _ in batch]
            ),
            "labels": torch.LongTensor(labels),
            "loss_mask": torch.LongTensor(loss_mask),
        }

        return processed_batch

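# End-to-end sketch (names are illustrative; `tokenizer` is any TokenizerSpec that
# exposes bos_id, eos_id and text_to_ids, e.g. the model's own tokenizer):
#
#     train_examples, dev_examples = load_datasets(cfg)
#     train_dataset = LLMJPSFTDataset(
#         loaded_examples=train_examples,
#         tokenizer=tokenizer,
#         use_loss_mask=True,
#         max_seq_length=4096,
#     )
#
# The resulting dataset and its collate_fn are then passed to build_dataloader or
# custom_build_dataloader as shown in the examples above.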