| import torch |
| import numpy as np |
|
|
def mask_to_rle(mask):
    """
    Convert a binary mask to run-length encoding (RLE).

    Only runs of 1s are recorded; runs of 0s are implicit in the gaps
    between recorded runs.

    Args:
        mask (torch.Tensor): tensor of binary values (0s and 1s); it is
            flattened, so any rank is accepted.

    Returns:
        list: flat list of (start, length) pairs, e.g. [s0, l0, s1, l1, ...].
            Empty list for an empty mask or an all-zero mask.
    """
    pixels = mask.detach().cpu().numpy().flatten()

    # Guard: indexing pixels[0] below would raise IndexError on empty input.
    if pixels.size == 0:
        return []

    rle = []
    last_val = pixels[0]
    start = 0
    length = 1

    for idx in range(1, len(pixels)):
        pixel = pixels[idx]
        if pixel == last_val:
            # Same value as the current run: extend it.
            length += 1
        else:
            # Run ended; record it only if it was a run of 1s.
            if last_val == 1:
                rle.extend([start, length])
            start = idx
            length = 1
            last_val = pixel

    # Flush the final run (the loop only records on value changes).
    if last_val == 1:
        rle.extend([start, length])

    return rle
|
|
def rle_encode(tensor):
    """
    Run-length encode every mask in a batch.

    Args:
        tensor (torch.Tensor): 2D tensor of shape [b, n] containing binary
            masks, one per row.

    Returns:
        tuple[list, list]: per-row RLE codes (each a flat start/length list)
            and the length of each code.
    """
    encoded = [mask_to_rle(row) for row in tensor]
    code_lengths = [len(code) for code in encoded]
    return encoded, code_lengths
|
|
def rle_to_mask(rle, shape):
    """
    Rebuild a binary mask from its run-length encoding.

    Args:
        rle (list): flat list of (start, length) pairs describing runs of 1s.
        shape (tuple): target shape (height, width) of the mask.

    Returns:
        numpy.ndarray: uint8 mask of the given shape with the runs set to 1.
    """
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)

    # Walk the encoding as (start, length) pairs and paint each run of 1s.
    for begin, run in zip(rle[0::2], rle[1::2]):
        flat[begin:begin + run] = 1

    return flat.reshape(shape)
|
|
def rle_decode(encoded_batches, shape):
    """
    Decode a batch of RLE codes back into a stacked mask tensor.

    Args:
        encoded_batches (list): one RLE code (flat start/length list) per row.
        shape (tuple): original batch shape (b, n); only n is used, each
            code is decoded into a row of width n.

    Returns:
        torch.Tensor: tensor of shape [b, n] holding the decoded binary masks.
    """
    rows = [rle_to_mask(code, (1, shape[1])) for code in encoded_batches]
    # Each row decodes to shape (1, n); squeeze that singleton axis away.
    return torch.tensor(rows).squeeze(1)
|
|
|
|
if __name__ == "__main__":
    # Round-trip demo: encode a small batch of binary masks, then decode.
    masks = torch.tensor([[0, 0, 1, 1, 0, 0, 1, 0, 1],
                          [1, 1, 0, 0, 0, 0, 1, 0, 0],
                          [0, 0, 1, 1, 1, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 1, 1, 0]]).float()
    codes, code_lengths = rle_encode(masks)
    print("Encoded RLE:", codes)
    print("Lengths:", code_lengths)

    restored = rle_decode(codes, masks.shape)
    print("Decoded Tensors:\n", restored.float())