| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from typing import Any, Dict |
|
|
|
|
class DynBszBuffer:
    """
    A buffer to store samples for dynamic batch size.

    Samples are kept together with their token count (the sum of the sample's
    mask). ``get_samples`` walks the buffer from a cursor and marks the samples
    it hands out; ``flush`` later drops the marked samples and rewinds the cursor.
    """

    def __init__(self):
        self._buffer = []  # the stored samples
        self._buffer_sample_lens = []  # token count of each stored sample, index-aligned with _buffer
        self.del_idxs = []  # indices already handed out by get_samples, removed on flush
        self.cur_idx = 0  # cursor of the next sample get_samples will look at
        self.all_token_cnt = 0  # total token count of every sample currently stored

    def append(self, item: Dict[str, Any]):
        """
        Append a sample to the buffer.
        Args:
            item: a sample to append to the buffer.
                The sample should be a dict with the following keys:
                - input_ids: torch.Tensor of shape (seq_len, )
                - attention_mask: torch.Tensor of shape (seq_len, )
                A ``lang_masks`` entry is accepted in place of ``attention_mask``.
        Raises:
            ValueError: if the sample has neither ``attention_mask`` nor ``lang_masks``.
        """
        if "attention_mask" in item:
            mask = item["attention_mask"]
        elif "lang_masks" in item:
            mask = item["lang_masks"]
        else:
            # Storing the sample without a length would leave _buffer and
            # _buffer_sample_lens misaligned, so reject it before touching state.
            raise ValueError("sample must contain either 'attention_mask' or 'lang_masks'")
        # Keep the count as a plain int so the bookkeeping does not hold on to
        # 0-d tensors and comparisons against it yield real booleans.
        n_tokens = int(mask.sum())
        self._buffer.append(item)
        self._buffer_sample_lens.append(n_tokens)
        self.all_token_cnt += n_tokens

    def get_samples(self, n_token_per_iter: int, force: bool = True):
        """
        Get samples from the buffer.
        Args:
            n_token_per_iter: the number of tokens to get.
            force: if True, a sample is taken while nothing has been collected yet,
                even when it is longer than ``n_token_per_iter``.
        Returns:
            samples: a list of samples.
        Raises:
            AssertionError: if no sample could be collected.
        """
        collected_tokens = 0
        samples = []
        handed_out = set(self.del_idxs)
        while self.cur_idx < len(self._buffer) and collected_tokens < n_token_per_iter:
            idx = self.cur_idx
            seq_len = self._buffer_sample_lens[idx]
            fits = seq_len <= n_token_per_iter - collected_tokens
            if idx not in handed_out and ((force and collected_tokens == 0) or fits):
                collected_tokens += seq_len
                samples.append(self._buffer[idx])
                self.del_idxs.append(idx)
            # The cursor advances even when the sample is skipped; skipped samples
            # are only looked at again after flush() rewinds the cursor.
            self.cur_idx += 1
        assert len(samples) > 0, "no sample could be collected from the buffer"
        return samples

    def __len__(self):
        return len(self._buffer)

    def flush(self):
        """
        Flush the buffer.

        Drops every sample handed out since the last flush, subtracts their
        tokens from ``all_token_cnt`` and rewinds the cursor to the start.
        """
        self.cur_idx = 0
        # A set makes the membership tests below O(1) instead of scanning the list.
        consumed = set(self.del_idxs)
        self.all_token_cnt -= sum(self._buffer_sample_lens[idx] for idx in consumed)
        self._buffer = [sample for idx, sample in enumerate(self._buffer) if idx not in consumed]
        self._buffer_sample_lens = [
            n_tokens for idx, n_tokens in enumerate(self._buffer_sample_lens) if idx not in consumed
        ]
        self.del_idxs = []

    def merge(self, buffer_to_merge: "DynBszBuffer"):
        """
        Merge the buffer with another buffer.

        Both buffers are flushed first, then every sample remaining in
        ``buffer_to_merge`` is appended to this buffer. ``buffer_to_merge``
        keeps its samples.
        Args:
            buffer_to_merge: the buffer to merge.
        """
        self.flush()
        buffer_to_merge.flush()
        # Iterate over a snapshot so that merging a buffer with itself terminates.
        for item in list(buffer_to_merge._buffer):
            self.append(item)
|
|
|
|
class BaseBatchingStrategy:
    """
    Base class for batching strategies.

    Every method here only raises ``NotImplementedError``; concrete strategies
    are expected to override all of them.
    """

    def is_full_filled(self) -> bool:
        """Hook to be overridden; the base version always raises."""
        hint = "should implement `is_full_filled`"
        raise NotImplementedError(hint)

    def put_item(self, item: Dict[str, Any]):
        """Hook to be overridden; the base version always raises."""
        hint = "should implement `put_item`"
        raise NotImplementedError(hint)

    def get_micro_batch(self, step: int) -> Any:
        """Hook to be overridden; the base version always raises."""
        hint = "should implement `get_micro_batch` "
        raise NotImplementedError(hint)

    def empty(self) -> bool:
        """Hook to be overridden; the base version always raises."""
        hint = "should implement `empty`"
        raise NotImplementedError(hint)
|
|
|
|
class IdentityPacker:
    """
    Packer that hands samples back unchanged and computes how many tokens to
    request for a step.
    """

    def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken

    def __call__(self, samples):
        """Return ``samples`` as-is."""
        return samples

    def get_token_num_to_request(self, cur_step, warmup):
        """
        Return the number of tokens to request at ``cur_step``.
        Args:
            cur_step: the current step.
            warmup: when truthy, interpolate (with floor division) from
                ``bsz_warmup_init_mbtoken`` towards ``token_micro_bsz`` over
                ``bsz_warmup_steps`` steps; otherwise use ``token_micro_bsz``.
        Returns:
            the token count for this step.
        """
        if not warmup:
            return self.token_micro_bsz
        ramp_span = self.token_micro_bsz - self.bsz_warmup_init_mbtoken
        ramped = ramp_span * cur_step // self.bsz_warmup_steps
        return ramped + self.bsz_warmup_init_mbtoken
|
|
|
|
class TextBatchingStrategy(BaseBatchingStrategy):
    """
    Batching strategy for text data.
    Args:
        token_micro_bsz: the number of tokens to get for each request.
        buffer_size: the size of the buffer.
        bsz_warmup_steps: the number of steps to warm up the batch size.
            Warmup is only active when this is positive.
        bsz_warmup_init_mbtoken: the initial number of tokens to get for each request.
    """

    def __init__(
        self,
        token_micro_bsz,
        buffer_size: int = 500,
        bsz_warmup_steps: int = -1,
        bsz_warmup_init_mbtoken: int = 200,
    ) -> None:
        super().__init__()
        self._step = 0
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.buffer_size = buffer_size
        self.buffer = DynBszBuffer()
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
        assert self.bsz_warmup_init_mbtoken >= 0

        self.packer = IdentityPacker(
            token_micro_bsz=token_micro_bsz,
            bsz_warmup_steps=bsz_warmup_steps,
            bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
        )

    def is_full_filled(self) -> bool:
        """
        Return True when the buffer holds at least ``buffer_size`` samples and
        at least ``token_micro_bsz`` tokens.
        """
        # The token counter may be a 0-d tensor (it accumulates mask sums), so
        # coerce the result to honour the declared bool return type.
        return bool(
            len(self.buffer) >= self.buffer_size and self.buffer.all_token_cnt >= self.token_micro_bsz
        )

    def put_item(self, item: Dict[str, Any]):
        """
        Add a sample to the buffer, skipping samples considered empty.
        Args:
            item: the sample. It is dropped (with a printed warning) when its
                ``input_ids`` has length 1, or, if it has no ``input_ids``, when
                every entry of its ``lang_tokens`` equals 0.
        """
        if "input_ids" in item:
            if len(item["input_ids"]) == 1:
                print("WARNING: EMPTY STRING.")
                return
        elif "lang_tokens" in item:
            if all(item["lang_tokens"] == 0):
                print("WARNING: EMPTY STRING.")
                return
        self.buffer.append(item)

    def _in_warmup(self) -> bool:
        """Return True while the current step falls inside an enabled warmup phase."""
        return self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0

    def get_token_num_to_request(self):
        """
        Return the number of tokens to ask the buffer for in one request.

        Delegates to the packer when one is set, otherwise falls back to the
        current token micro batch size.
        """
        if self.packer is not None:
            return self.packer.get_token_num_to_request(self._step, warmup=self._in_warmup())
        return self.get_cur_token_micro_bsz()

    def get_cur_token_micro_bsz(self):
        """
        Return the token micro batch size for the current step.

        During warmup the value grows (with floor division) from
        ``bsz_warmup_init_mbtoken`` to ``token_micro_bsz`` over
        ``bsz_warmup_steps`` steps; afterwards it is ``token_micro_bsz``.
        """
        if not self._in_warmup():
            return self.token_micro_bsz
        ramp_span = self.token_micro_bsz - self.bsz_warmup_init_mbtoken
        return ramp_span * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken

    def get_micro_batch(self, step) -> Any:
        """
        Get a micro batch from the buffer according to the current step.
        Args:
            step: the current step.
        Returns:
            data: a list of samples.
        Raises:
            AssertionError: if the current token micro batch size is not a
                multiple of the per-request token count.
            ZeroDivisionError: if the per-request token count is 0, e.g. at
                step 0 of a warmup with ``bsz_warmup_init_mbtoken == 0``.
        """
        self._step = step
        n_token_per_iter = self.get_token_num_to_request()
        cur_token_micro_bsz = self.get_cur_token_micro_bsz()
        assert cur_token_micro_bsz % n_token_per_iter == 0, (
            "The token num to get for each request should be divisible by token micro bsz."
        )
        n_iter = int(cur_token_micro_bsz // n_token_per_iter)
        data = []
        for _ in range(n_iter):
            samples = self.buffer.get_samples(n_token_per_iter)
            if self.packer is not None:
                samples = self.packer(samples)
            data.extend(samples)
        # Drop the samples handed out above and rewind the buffer cursor.
        self.buffer.flush()
        return data

    def empty(self) -> bool:
        """Return True when the buffer holds no samples."""
        return len(self.buffer) == 0
|
|