|
|
|
|
|
from typing import List, Union, Tuple, Optional |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers.image_processing_utils import ImageProcessingMixin |
|
|
|
|
|
# Accepted image input types for this processor: numpy array, torch tensor, or PIL image.
ArrayLike = Union[np.ndarray, torch.Tensor, Image.Image]
|
|
|
|
|
def _to_rgb_numpy(im: ArrayLike) -> np.ndarray:
    """Convert a PIL image, 3D torch tensor, or array-like to a float32 HxWx3 RGB array.

    Output values are in 0..255 space: float inputs whose max is <= 1.5 are
    assumed to be in ~[0, 1] and are rescaled by 255 (heuristic).

    Raises:
        ValueError: if a tensor input is not 3D, or the result is not HxWx3.
    """
    if isinstance(im, Image.Image):
        # PIL path: force RGB mode, keep 0..255 values, cast to float32.
        if im.mode != "RGB":
            im = im.convert("RGB")
        arr = np.array(im, dtype=np.uint8).astype(np.float32)
    elif isinstance(im, torch.Tensor):
        t = im.detach().cpu()
        if t.ndim != 3:
            raise ValueError("Tensor must be 3D (CHW or HWC).")
        # Layout heuristic: a leading dim of 1 or 3 is treated as channels (CHW).
        if t.shape[0] in (1, 3):
            if t.shape[0] == 1:
                # Grayscale CHW -> replicate to 3 channels.
                t = t.repeat(3, 1, 1)
            t = t.permute(1, 2, 0)  # CHW -> HWC
        elif t.shape[-1] == 1:
            # Grayscale HWC -> replicate to 3 channels.
            t = t.repeat(1, 1, 3)
        arr = t.numpy()
        if arr.dtype in (np.float32, np.float64) and arr.max() <= 1.5:
            # Heuristic: floats in ~[0, 1] are rescaled to 0..255 space.
            arr = (arr * 255.0).astype(np.float32)
        else:
            arr = arr.astype(np.float32)
    else:
        # Generic array-like path (numpy array, nested lists, ...).
        arr = np.array(im)
        if arr.ndim == 2:
            # HxW grayscale -> HxWx3.
            arr = np.repeat(arr[..., None], 3, axis=-1)
        arr = arr.astype(np.float32)
        if arr.max() <= 1.5:
            # NOTE(review): this also rescales genuinely dark 0..255 inputs
            # (all values <= 1.5); assumed acceptable for this pipeline — confirm.
            arr = (arr * 255.0).astype(np.float32)
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError("Expected RGB image with shape HxWx3.")
    return arr
|
|
|
|
|
def _letterbox_keep_ratio(arr: np.ndarray, target_hw: Tuple[int, int]):
    """Aspect-preserving resize followed by centered zero (black) padding.

    Args:
        arr: image as (H, W, 3) float array in 0..255 space.
        target_hw: desired (height, width) of the output canvas.

    Returns:
        canvas: (target_h, target_w, 3) float32 array, black where padded.
        (top, left, new_h, new_w): placement of the resized image on the canvas.
    """
    canvas_h, canvas_w = target_hw
    src_h, src_w = arr.shape[:2]

    # Uniform scale that fits the source inside the target on both axes.
    ratio = min(canvas_h / src_h, canvas_w / src_w)
    new_h = int(round(src_h * ratio))
    new_w = int(round(src_w * ratio))
    if new_h <= 0 or new_w <= 0:
        raise ValueError("Invalid resize result.")

    # Round-trip through PIL for bilinear resampling.
    resized = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8)).resize(
        (new_w, new_h), resample=Image.BILINEAR
    )
    resized_np = np.array(resized, dtype=np.float32)

    # Center the resized image on a black canvas.
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.float32)
    pad_top = (canvas_h - new_h) // 2
    pad_left = (canvas_w - new_w) // 2
    canvas[pad_top:pad_top + new_h, pad_left:pad_left + new_w] = resized_np
    return canvas, (pad_top, pad_left, new_h, new_w)
|
|
|
|
|
def _zscore_ignore_black(chw: np.ndarray, eps: float = 1e-8) -> np.ndarray: |
|
|
mask = (chw.sum(axis=0) > 0) |
|
|
if not mask.any(): |
|
|
return chw.copy() |
|
|
valid = chw[:, mask] |
|
|
mean = valid.mean() |
|
|
std = valid.std() |
|
|
return (chw - mean) / std if std > eps else (chw - mean) |
|
|
|
|
|
class FilmUnet2DImageProcessor(ImageProcessingMixin):
    """
    Processor for FILMUnet2D:

    - Convert to RGB
    - Keep-aspect-ratio resize+pad (letterbox) to 512x512 (configurable)
    - Normalize with mean/std in 0-255 space (matches training)
    - Optional z-score 'self_norm' ignoring black pixels

    __call__ returns a dict with:

    - pixel_values: torch.FloatTensor [B,3,H,W]
    - original_sizes: torch.LongTensor [B,2] (H,W)
    - letterbox_params: torch.LongTensor [B,4] (top, left, nh, nw)
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Tuple[int, int] = (512, 512),
        keep_ratio: bool = True,
        image_mean: Tuple[float, float, float] = (123.675, 116.28, 103.53),
        image_std: Tuple[float, float, float] = (58.395, 57.12, 57.375),
        self_norm: bool = False,
        **kwargs,
    ):
        """
        Args:
            do_resize: resize every image to ``size`` (letterbox or plain resize).
            size: target (height, width) of the model input.
            keep_ratio: preserve aspect ratio and pad (letterbox) instead of stretching.
            image_mean / image_std: per-channel normalization constants in 0-255
                space (defaults are the standard ImageNet constants).
            self_norm: additionally z-score each image, ignoring black pixels.
        """
        super().__init__(**kwargs)
        self.do_resize = bool(do_resize)
        self.size = tuple(size)
        self.keep_ratio = bool(keep_ratio)
        self.image_mean = tuple(float(x) for x in image_mean)
        self.image_std = tuple(float(x) for x in image_std)
        self.self_norm = bool(self_norm)

    def __call__(
        self,
        images: Union[ArrayLike, List[ArrayLike]],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ):
        """Preprocess a single image or a list of images into a model-ready batch.

        Returns:
            dict with 'pixel_values', 'original_sizes' and 'letterbox_params'
            (torch tensors when ``return_tensors == "pt"``, plain lists otherwise).
        """
        imgs = images if isinstance(images, (list, tuple)) else [images]

        # Normalization constants are loop-invariant; build them once, not per image.
        mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3)
        std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3)

        batch = []
        orig_sizes = []
        lb_params = []

        for im in imgs:
            arr = _to_rgb_numpy(im)
            oh, ow = arr.shape[:2]
            orig_sizes.append((oh, ow))

            if self.do_resize:
                if self.keep_ratio:
                    arr, meta = _letterbox_keep_ratio(arr, self.size)
                else:
                    # Plain (aspect-distorting) resize to the full target size.
                    h, w = self.size
                    pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
                    arr = np.array(pil.resize((w, h), resample=Image.BILINEAR), dtype=np.float32)
                    meta = (0, 0, h, w)
            else:
                # No resizing: center the image on a zero canvas of `size` so the
                # batch still stacks; oversized images are cropped at the top-left.
                h, w = arr.shape[:2]
                top = max((self.size[0] - h) // 2, 0)
                left = max((self.size[1] - w) // 2, 0)
                # Extent actually placed on the canvas (clamped for large inputs).
                ph = min(h, self.size[0] - top)
                pw = min(w, self.size[1] - left)
                out = np.zeros((*self.size, 3), dtype=np.float32)
                out[top:top + ph, left:left + pw] = arr[:ph, :pw]
                arr = out
                # Bug fix: record the clamped placed extent. The previous code
                # stored (top, left, h, w), which over-reported the region for
                # images larger than `size` and skewed post-processing crops.
                meta = (top, left, ph, pw)

            lb_params.append(meta)

            # Channel-wise normalization in 0-255 space.
            arr = (arr - mean) / std

            chw = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
            if self.self_norm:
                # NOTE(review): at this point the black padding is already
                # mean/std-normalized (non-zero), so _zscore_ignore_black's
                # "channel sum > 0" test no longer exactly selects padding.
                # Kept as-is to match training — confirm before changing.
                chw = _zscore_ignore_black(chw)
            batch.append(chw)

        pixel_values = np.stack(batch, axis=0)
        if return_tensors == "pt":
            pixel_values = torch.from_numpy(pixel_values).to(torch.float32)
            original_sizes = torch.tensor(orig_sizes, dtype=torch.long)
            letterbox_params = torch.tensor(lb_params, dtype=torch.long)
        else:
            original_sizes = orig_sizes
            letterbox_params = lb_params

        return {
            "pixel_values": pixel_values,
            "original_sizes": original_sizes,
            "letterbox_params": letterbox_params
        }

    def post_process_semantic_segmentation(
        self,
        outputs: dict,
        processor_inputs: Optional[dict] = None,
        threshold: float = 0.5,
        return_as_pil: bool = True,
    ):
        """
        Turn model outputs into masks resized back to the ORIGINAL image sizes,
        with letterbox padding removed.

        Args:
            outputs: dict from model forward (expects 'logits': [B,1,H,W])
            processor_inputs: the dict returned by __call__ (must contain
                'original_sizes' [B,2] and 'letterbox_params' [B,4])
            threshold: probability threshold for binarization
            return_as_pil: return a list of PIL Images (uint8 0/255) if True,
                else a list of torch tensors [H,W] uint8

        Returns:
            List of masks back in original sizes (H,W).

        Raises:
            ValueError: if ``processor_inputs`` is not provided.
        """
        # Validate before doing any work.
        if processor_inputs is None:
            raise ValueError("processor_inputs must be provided to undo letterboxing.")

        logits = outputs["logits"]
        probs = torch.sigmoid(logits)
        # Binarize to a 0/255 uint8 mask.
        masks = (probs > threshold).to(torch.uint8) * 255

        orig_sizes = processor_inputs["original_sizes"]
        lb_params = processor_inputs["letterbox_params"]

        def _as_ints(row):
            # Bug fix: accept both torch tensor rows (return_tensors="pt") and
            # plain tuples/lists (return_tensors=None). The previous code called
            # .tolist() unconditionally and crashed on plain sequences.
            vals = row.tolist() if hasattr(row, "tolist") else row
            return [int(x) for x in vals]

        results = []
        for i in range(masks.shape[0]):
            m = masks[i, 0]
            top, left, nh, nw = _as_ints(lb_params[i])

            # Remove the letterbox padding.
            m_cropped = m[top:top + nh, left:left + nw]

            # Resize back to the original image size; nearest keeps the mask binary.
            oh, ow = _as_ints(orig_sizes[i])
            m_resized = torch.nn.functional.interpolate(
                m_cropped.unsqueeze(0).unsqueeze(0).float(),
                size=(oh, ow),
                mode="nearest",
            )[0, 0].to(torch.uint8)

            if return_as_pil:
                results.append(Image.fromarray(m_resized.cpu().numpy(), mode="L"))
            else:
                results.append(m_resized)

        return results
|
|
|