# VisualSemSeg / utils/imageHandling.py
# Last commit ff83735 by Nunzio: "fixed errors"
import torch, torchvision
# %% image loading
def hfImageToTensor(image, width: int = 1024, height: int = 512) -> torch.Tensor:
    """
    Convert an input image from Hugging Face to a resized float32 torch tensor.

    PIL images and numpy arrays are converted with ``to_tensor`` (HWC -> CHW).
    Tensors are assumed to already be channel-first and skip that conversion.
    To honour the float32 contract on every path, uint8 tensors are rescaled
    from [0, 255] to [0, 1] and floating tensors are cast to float32.

    Args:
        image: Input image (PIL.Image, numpy array, or channel-first torch.Tensor).
        width (int): Target width.
        height (int): Target height.

    Returns:
        torch.Tensor: float32 image tensor of shape (C, height, width). C is 3 for
        RGB inputs; other image modes (e.g. grayscale) keep their own channel count.
    """
    tensor = image if isinstance(image, torch.Tensor) else torchvision.transforms.functional.to_tensor(image)
    if tensor.dtype == torch.uint8:
        # Only reachable for tensor inputs (to_tensor already rescales uint8 data).
        # Without this, a uint8 tensor would be returned as uint8 and later fail
        # the float-only normalization step.
        tensor = tensor.to(torch.float32).div(255.0)
    elif tensor.is_floating_point():
        # e.g. a float64 numpy array or tensor: cast to the documented dtype.
        tensor = tensor.to(torch.float32)
    return torchvision.transforms.functional.resize(tensor, [height, width])
# %% preprocessing
def preprocessing(image_tensor: torch.Tensor) -> torch.Tensor:
    """
    Normalize an image tensor per channel and prepend a batch axis.

    The per-channel mean/std used here are the standard ImageNet statistics.

    Args:
        image_tensor (torch.Tensor): Image tensor of shape (3, H, W); presumably a
            float tensor in [0, 1] such as the output of hfImageToTensor (confirm).

    Returns:
        torch.Tensor: Normalized tensor of shape (1, 3, H, W).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    normalized = torchvision.transforms.functional.normalize(
        image_tensor, mean=imagenet_mean, std=imagenet_std
    )
    # Add a leading batch dimension of size 1.
    return torch.unsqueeze(normalized, dim=0)
# %% print mask on a sem seg style
def print_mask(mask: torch.Tensor, numClasses: int = 19) -> torch.Tensor:
    """
    Colorize a segmentation mask by mapping each class index to a fixed RGB color.

    Despite its name, this function does not print or display anything. It
    returns the color-coded image so the caller can show or save it.

    Args:
        mask (torch.Tensor): 2-D (H, W) tensor of class indices.
        numClasses (int, optional): Number of classes to colorize, at most the
            palette size (19). Defaults to 19.

    Returns:
        torch.Tensor: uint8 tensor of shape (3, H, W) on the same device as ``mask``.
        Pixels whose value is not in ``range(numClasses)`` are black; this includes
        the ignore label 255.

    Raises:
        ValueError: If ``mask`` is not 2-D or ``numClasses`` exceeds the palette size.
    """
    colors = [
        (128, 64, 128),   # 0: road
        (244, 35, 232),   # 1: sidewalk
        (70, 70, 70),     # 2: building
        (102, 102, 156),  # 3: wall
        (190, 153, 153),  # 4: fence
        (153, 153, 153),  # 5: pole
        (250, 170, 30),   # 6: traffic light
        (220, 220, 0),    # 7: traffic sign
        (107, 142, 35),   # 8: vegetation
        (152, 251, 152),  # 9: terrain
        (70, 130, 180),   # 10: sky
        (220, 20, 60),    # 11: person
        (255, 0, 0),      # 12: rider
        (0, 0, 142),      # 13: car
        (0, 0, 70),       # 14: truck
        (0, 60, 100),     # 15: bus
        (0, 80, 100),     # 16: train
        (0, 0, 230),      # 17: motorcycle
        (119, 11, 32),    # 18: bicycle
    ]
    if mask.dim() != 2:
        raise ValueError(f"mask must be 2-D (H, W), got shape {tuple(mask.shape)}")
    if numClasses > len(colors):
        raise ValueError(
            f"numClasses={numClasses} exceeds the {len(colors)} available colors"
        )
    # Build the palette once, on the mask's device, so boolean indexing also
    # works for masks that do not live on the CPU.
    palette = torch.tensor(colors, dtype=torch.uint8, device=mask.device)
    # Start all-black; unmatched values (e.g. the 255 ignore label) stay black.
    new_mask = torch.zeros((*mask.shape, 3), dtype=torch.uint8, device=mask.device)
    for class_id in range(numClasses):
        new_mask[mask == class_id] = palette[class_id]
    # HWC -> CHW
    return new_mask.permute(2, 0, 1)
# %% postprocessing
def postprocessing(pred: torch.Tensor, numClasses: int = 19) -> "PIL.Image.Image":
    """
    Convert a predicted label map into a color-coded PIL image for visualization.

    Args:
        pred (torch.Tensor): Predicted class indices of shape (1, H, W).
        numClasses (int, optional): Number of classes to colorize. Forwarded to
            ``print_mask``. Defaults to 19.

    Returns:
        PIL.Image.Image: RGB image of the colorized mask (not a tensor).
    """
    # Drop the leading batch axis and move to CPU for PIL conversion. The uint8
    # cast keeps the ignore label 255 intact; values outside [0, 255] would wrap.
    label_map = pred.squeeze(0).cpu().to(torch.uint8)
    color_mask = print_mask(label_map, numClasses)  # (3, H, W) uint8
    return torchvision.transforms.functional.to_pil_image(color_mask)