# MVANet / image_processing_mvanet.py
# Uploaded by shunk031 (commit f685c19, verified)
"""Image processor for MVANet model."""
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BaseImageProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import (
ImageInput,
PILImageResampling,
)
from transformers.utils import TensorType
def to_pil_image(image: Union[np.ndarray, torch.Tensor, Image.Image]) -> Image.Image:
    """Convert a numpy array, torch tensor, or PIL image to a PIL Image.

    Tensors are moved to CPU and converted to uint8 numpy arrays first.
    Channel-first ``(C, H, W)`` tensors (C in {1, 3, 4}) are transposed to
    channel-last; floating-point data is assumed to lie in ``[0, 1]`` and is
    rescaled to 8-bit, while integer data is assumed to already be in
    ``[0, 255]``. Numpy arrays may be 2-D (grayscale) or ``(H, W, C)`` with
    1, 3, or 4 channels.

    Raises:
        ValueError: If the input type or channel layout is unsupported.
    """
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, torch.Tensor):
        tensor = image.detach().cpu()
        # Channel-first (C, H, W) -> channel-last (H, W, C) for PIL.
        if tensor.ndim == 3 and tensor.shape[0] in (1, 3, 4):
            tensor = tensor.permute(1, 2, 0)
        array = tensor.numpy()
        # Only floating-point data is rescaled; uint8 data is already 0-255,
        # and multiplying it by 255 would saturate every pixel.
        if np.issubdtype(array.dtype, np.floating):
            array = array * 255
        image = array.clip(0, 255).astype(np.uint8)
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            # Grayscale without an explicit channel axis.
            return Image.fromarray(image, mode="L")
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                return Image.fromarray(image.squeeze(2), mode="L")
            if channels == 3:
                return Image.fromarray(image, mode="RGB")
            if channels == 4:
                return Image.fromarray(image, mode="RGBA")
    raise ValueError(f"Unsupported image type: {type(image)}")
class MVANetImageProcessor(BaseImageProcessor):
    """
    Constructs a MVANet image processor.
    Args:
        do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to resize the image.
        size (:obj:`Dict[str, int]`, `optional`, defaults to :obj:`{"height": 1024, "width": 1024}`):
            Target size for resizing. MVANet was trained on 1024x1024 images.
        resample (:obj:`PILImageResampling`, `optional`, defaults to :obj:`PILImageResampling.BILINEAR`):
            Resampling filter to use when resizing the image.
        do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to rescale to [0, 1] and normalize the image.
        image_mean (:obj:`List[float]`, `optional`, defaults to :obj:`[0.485, 0.456, 0.406]`):
            Mean to use for normalization (ImageNet mean).
        image_std (:obj:`List[float]`, `optional`, defaults to :obj:`[0.229, 0.224, 0.225]`):
            Standard deviation to use for normalization (ImageNet std).
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_normalize: bool = True,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_resize = do_resize
        # MVANet was trained on 1024x1024 inputs.
        self.size = size if size is not None else {"height": 1024, "width": 1024}
        self.resample = resample
        self.do_normalize = do_normalize
        # ImageNet statistics by default.
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]

    def resize(
        self,
        image: Image.Image,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
    ) -> Image.Image:
        """Resize ``image`` to ``size`` (``{"height": ..., "width": ...}``).

        Note: ``PIL.Image.resize`` takes ``(width, height)`` order.
        """
        return image.resize((size["width"], size["height"]), resample)

    def normalize(
        self,
        image: np.ndarray,
        mean: List[float],
        std: List[float],
    ) -> np.ndarray:
        """Rescale an 8-bit ``(H, W, C)`` image to [0, 1] and normalize it.

        NOTE: unlike ``BaseImageProcessor.normalize``, this also divides by
        255 first, so it expects uint8 pixel values as produced by
        ``np.array`` on a PIL image.
        """
        image = image.astype(np.float32) / 255.0
        # Per-channel mean/std broadcast over the trailing channel axis.
        mean_arr = np.array(mean, dtype=np.float32)
        std_arr = np.array(std, dtype=np.float32)
        return (image - mean_arr) / std_arr

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images for MVANet.
        Args:
            images (:obj:`ImageInput`):
                Images to preprocess. Can be a single image or a batch
                (list or tuple) of images.
            do_resize (:obj:`bool`, `optional`):
                Whether to resize the image(s). Defaults to :obj:`self.do_resize`.
            size (:obj:`Dict[str, int]`, `optional`):
                Target size for resizing. Defaults to :obj:`self.size`.
            resample (:obj:`PILImageResampling`, `optional`):
                Resampling filter to use. Defaults to :obj:`self.resample`.
            do_normalize (:obj:`bool`, `optional`):
                Whether to normalize the image(s). Defaults to :obj:`self.do_normalize`.
            image_mean (:obj:`List[float]`, `optional`):
                Mean for normalization. Defaults to :obj:`self.image_mean`.
            image_std (:obj:`List[float]`, `optional`):
                Std for normalization. Defaults to :obj:`self.image_std`.
            return_tensors (:obj:`str` or :obj:`TensorType`, `optional`):
                Type of tensors to return. Can be 'pt' for PyTorch.
        Returns:
            :obj:`BatchFeature`: A :obj:`BatchFeature` with the following fields:
            - pixel_values (:obj:`torch.Tensor`): Preprocessed images of
              shape ``(batch, 3, H, W)``.
        """
        # Fall back to processor-level defaults for any unset option.
        do_resize = self.do_resize if do_resize is None else do_resize
        size = self.size if size is None else size
        resample = self.resample if resample is None else resample
        do_normalize = self.do_normalize if do_normalize is None else do_normalize
        image_mean = self.image_mean if image_mean is None else image_mean
        image_std = self.image_std if image_std is None else image_std

        # Accept a single image or a list/tuple batch.
        if not isinstance(images, (list, tuple)):
            images = [images]

        # Convert every input to an RGB PIL image.
        pil_images = []
        for image in images:
            pil_image = to_pil_image(image)
            if pil_image.mode != "RGB":
                pil_image = pil_image.convert("RGB")
            pil_images.append(pil_image)

        if do_resize:
            pil_images = [self.resize(img, size, resample) for img in pil_images]

        # (H, W, C) uint8 arrays.
        arrays = [np.array(img) for img in pil_images]
        if do_normalize:
            arrays = [self.normalize(arr, image_mean, image_std) for arr in arrays]

        # (H, W, C) -> (C, H, W), then stack into a (B, C, H, W) batch.
        batch = np.stack([arr.transpose(2, 0, 1) for arr in arrays])
        if return_tensors == "pt":
            # from_numpy avoids the extra copy made by torch.tensor; the
            # float32 cast matches the model's expected input dtype.
            pixel_values = torch.from_numpy(batch).to(torch.float32)
        else:
            pixel_values = batch
        return BatchFeature(
            data={"pixel_values": pixel_values}, tensor_type=return_tensors
        )

    def post_process_semantic_segmentation(
        self,
        outputs,
        target_sizes: Optional[List[Tuple[int, int]]] = None,
    ) -> List[torch.Tensor]:
        """
        Post-process model outputs to semantic segmentation masks.
        Args:
            outputs (:obj:`SemanticSegmenterOutput` or :obj:`torch.Tensor`):
                Model outputs containing logits of shape ``(B, 1, H, W)``.
            target_sizes (:obj:`List[Tuple[int, int]]`, `optional`):
                List of target sizes (width, height) for each image.
                If not provided, returns masks at model output size.
        Returns:
            :obj:`List[torch.Tensor]`: List of ``(H, W)`` segmentation
            masks (values in [0, 1]).
        """
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        # Sigmoid turns the single-channel logits into per-pixel
        # foreground probabilities; shape stays (B, 1, H, W).
        probs = torch.sigmoid(logits)

        if target_sizes is None:
            # Drop the channel dim: (1, H, W) -> (H, W) per image.
            return [probs[i].squeeze(0) for i in range(probs.shape[0])]

        masks = []
        for i, (target_w, target_h) in enumerate(target_sizes):
            # F.interpolate expects size=(height, width) while
            # target_sizes stores (width, height).
            resized = F.interpolate(
                probs[i : i + 1],
                size=(target_h, target_w),
                mode="bilinear",
                align_corners=False,
            )
            masks.append(resized.squeeze(0).squeeze(0))  # (H, W)
        return masks