| from typing import Generator, Iterable, List, TypeVar, Union |
|
|
# Generic element type used by the signatures of calculate_input_elements and
# create_batches, tying the element type of their inputs to their outputs.
B = TypeVar("B")
|
|
|
|
def calculate_input_elements(input_value: Union[B, List[B]]) -> int:
    """Return how many elements *input_value* represents.

    A ``list`` (including instances of ``list`` subclasses) counts as that
    many elements. Any other value counts as a single element, including
    other sequences such as tuples or strings.

    Args:
        input_value: Either a single element or a list of elements.

    Returns:
        ``len(input_value)`` when it is a list, otherwise ``1``.
    """
    # Only genuine lists are unpacked, so element values that are themselves
    # sequences (e.g. tuples or strings) are still counted as one element.
    if isinstance(input_value, list):
        return len(input_value)
    return 1
|
|
|
|
def create_batches(
    sequence: Iterable[B], batch_size: int
) -> Generator[List[B], None, None]:
    """Split *sequence* into consecutive lists of at most *batch_size* items.

    Items keep their original order. Every batch except possibly the last
    contains exactly ``batch_size`` items. The final batch holds whatever
    remains and is omitted when nothing remains, so an empty *sequence*
    yields no batches at all.

    Args:
        sequence: Any iterable. It is consumed lazily, one item at a time.
        batch_size: Desired batch length. Values below 1 are treated as 1.

    Yields:
        Lists of items taken from *sequence*.
    """
    # Clamp so zero or negative sizes still produce one-item batches.
    batch_size = max(batch_size, 1)
    current_batch: List[B] = []
    for element in sequence:
        current_batch.append(element)
        # Emit as soon as the batch is full instead of waiting for the next
        # element. Otherwise one extra item is pulled from *sequence* before
        # the batch is delivered, and that item is lost if the caller stops
        # early. Waiting would also stall on slow or blocking sources.
        if len(current_batch) == batch_size:
            yield current_batch
            current_batch = []
    # Flush the trailing partial batch, if any.
    if current_batch:
        yield current_batch
|
|