# mimc_rl/util/rle.py
# wangyanhui666 -- "fine tune decoder with mask" (9cf79cf)
import torch
import numpy as np
def mask_to_rle(mask):
    """
    Convert a binary mask to RLE.

    Only runs of 1s are recorded; every other value is treated as background.

    Args:
        mask (torch.Tensor): Binary mask (0s and 1s). It is flattened before
            encoding, so run starts are offsets into the flattened mask.

    Returns:
        list: Flat list ``[start0, length0, start1, length1, ...]`` of Python
            ints, one (start, length) pair per run of 1s. Empty when the mask
            contains no 1s or has no elements.
    """
    pixels = mask.detach().cpu().numpy().flatten()
    is_one = pixels == 1

    # Pad with background on both sides so that runs touching either end of
    # the mask still produce a rising and a falling edge. This also makes an
    # empty mask fall through naturally to an empty result.
    padded = np.concatenate(([False], is_one, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])

    # Edges alternate rising/falling: even positions are run starts, odd
    # positions are the index one past the end of the run.
    starts = edges[0::2]
    run_lengths = edges[1::2] - starts

    # Interleave into [start, length, start, length, ...] as plain ints.
    return np.stack((starts, run_lengths), axis=1).ravel().tolist()
def rle_encode(tensor):
    """
    Run-length encode every mask in a batch.

    Args:
        tensor (torch.Tensor): 2D tensor of shape [b, n]; each row is one
            binary mask.

    Returns:
        list, list: The per-row RLE codes (as produced by ``mask_to_rle``)
            and, in the same order, the number of entries in each code.
    """
    rle_batches = [mask_to_rle(row) for row in tensor]
    lengths = [len(code) for code in rle_batches]
    return rle_batches, lengths
def rle_to_mask(rle, shape):
    """
    Convert RLE back to binary mask.

    Args:
        rle (list): Flat list ``[start0, length0, start1, length1, ...]``;
            each (start, length) pair marks a run of 1s in the flattened
            (row-major) mask.
        shape (tuple): Shape of the mask to rebuild, e.g. ``(height, width)``.
            Shapes of any rank are accepted, not only 2D ones.

    Returns:
        numpy.ndarray: ``uint8`` array of the requested shape with 1s inside
            the encoded runs and 0s elsewhere.
    """
    mask = np.zeros(shape, dtype=np.uint8)
    # A freshly allocated array is contiguous, so this is a writable view:
    # filling ``flat`` fills ``mask`` without assuming a particular rank.
    flat = mask.reshape(-1)
    for i in range(0, len(rle), 2):
        start = rle[i]
        flat[start:start + rle[i + 1]] = 1
    return mask
def rle_decode(encoded_batches, shape):
    """
    Decode a batch of RLE encoded masks.

    Args:
        encoded_batches (list): List of RLE encoded masks, each a flat list
            ``[start0, length0, start1, length1, ...]`` of runs of 1s.
        shape (tuple): Shape of the original batch (b, n). Only ``shape[1]``
            (the per-mask length) is used; the number of rows follows
            ``encoded_batches``.

    Returns:
        torch.Tensor: ``uint8`` tensor of shape
            ``(len(encoded_batches), shape[1])`` holding the decoded binary
            masks. An empty batch yields a ``(0, shape[1])`` tensor.
    """
    encoded_batches = list(encoded_batches)
    mask_len = shape[1]

    # Decode straight into one preallocated buffer. This avoids building a
    # tensor from a Python list of arrays (slow, and warned about by PyTorch)
    # and keeps the empty-batch case well defined.
    masks = np.zeros((len(encoded_batches), mask_len), dtype=np.uint8)
    for row, rle in zip(masks, encoded_batches):
        # ``row`` is a view into ``masks``; restore this mask's runs of 1s.
        for i in range(0, len(rle), 2):
            start = rle[i]
            row[start:start + rle[i + 1]] = 1
    return torch.from_numpy(masks)
if __name__ == "__main__":
    # Round-trip demo: encode a small batch of masks, then decode it again.
    demo_rows = [
        [0, 0, 1, 1, 0, 0, 1, 0, 1],
        [1, 1, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 0],
    ]
    masks = torch.tensor(demo_rows).float()

    codes, code_lengths = rle_encode(masks)
    print("Encoded RLE:", codes)
    print("Lengths:", code_lengths)

    # The batch shape (b, n) tells the decoder how long each mask is.
    restored = rle_decode(codes, masks.shape)
    print("Decoded Tensors:\n", restored.float())