Spaces:
Running
Running
"""
Stream Compaction (Filter)

Removes elements that don't satisfy a predicate and packs the survivors,
in their original order, into a contiguous output. Also known as
filtering; closely related to partitioning, which additionally keeps the
rejected elements instead of discarding them.

Example: remove all zeros from an array. (The model below uses a threshold
predicate and keeps only elements >= threshold.)

Optimization opportunities:
- Scan-based compaction
- Warp-level ballot for predicate evaluation
- Per-block compaction + global gather
- Decoupled lookback for single-pass
"""
| import torch | |
| import torch.nn as nn | |
class Model(nn.Module):
    """
    Threshold-based stream compaction.

    Keeps only the elements of the input that are greater than or equal to
    ``threshold`` and packs them, in their original order, into a 1-D
    output tensor.
    """

    def __init__(self, threshold: float = 0.5):
        super().__init__()
        # Cut-off value: elements strictly below it are discarded.
        self.threshold = threshold

    def forward(self, input: torch.Tensor) -> tuple:
        """
        Compact ``input`` down to the elements that pass the threshold test.

        Args:
            input: tensor to filter (the problem setup feeds a 1-D array of
                length N).

        Returns:
            A ``(output, count)`` tuple: ``output`` is a 1-D tensor of the
            kept elements in their original order (length M <= N), and
            ``count`` is a 0-d integer tensor holding M.
        """
        # Boolean predicate, same shape as ``input``; NaNs compare False
        # and are therefore dropped.
        keep = torch.ge(input, self.threshold)
        return torch.masked_select(input, keep), keep.sum()
# Problem configuration: number of elements in the generated input array
# (16 * 1024 * 1024 = 16,777,216).
array_size = 16 << 20
def get_inputs():
    """
    Build the forward-pass inputs for ``Model``.

    Returns:
        A one-element list holding a 1-D float tensor of ``array_size``
        values drawn uniformly from [0, 1).
    """
    return [torch.rand(array_size)]
def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    threshold = 0.5
    return [threshold]