US_FiLMUNet / image_processing_film_unet2d.py
Morelli001's picture
Upload folder using huggingface_hub
aee1a39 verified
# image_processing_film_unet2d.py
from typing import List, Union, Tuple, Optional
import numpy as np
from PIL import Image
import torch
from transformers.image_processing_utils import ImageProcessingMixin
# Any single image the processor accepts: a NumPy array, a torch tensor, or a PIL image.
ArrayLike = Union[np.ndarray, torch.Tensor, Image.Image]
def _to_rgb_numpy(im: ArrayLike) -> np.ndarray:
    """Convert one image to a float32 HxWx3 array on the 0-255 scale.

    Accepted inputs:
      * PIL image: converted to RGB when it is in any other mode.
      * torch.Tensor: must be 3D. A leading dimension of 1 or 3 is read as
        CHW; otherwise the tensor is read as HWC. Single-channel data is
        replicated to 3 channels.
      * anything else: passed through ``np.array``; HxW and HxWx1 inputs are
        replicated to 3 channels.

    Floating-point or boolean data whose maximum is <= 1.5 is assumed to be on
    a 0-1 scale and is multiplied by 255. Integer data is never rescaled, so a
    very dark uint8 image keeps its pixel values.

    Raises:
        ValueError: if a tensor is not 3D, or the converted image is not
            HxWx3.
    """
    if isinstance(im, Image.Image):
        if im.mode != "RGB":
            im = im.convert("RGB")
        arr = np.array(im, dtype=np.uint8)
    elif isinstance(im, torch.Tensor):
        t = im.detach().cpu()
        if t.ndim != 3:
            raise ValueError("Tensor must be 3D (CHW or HWC).")
        if t.is_floating_point():
            # float16 / bfloat16 convert poorly (or not at all) to NumPy, and
            # the result is float32 anyway.
            t = t.to(torch.float32)
        if t.shape[0] in (1, 3):  # CHW
            if t.shape[0] == 1:
                t = t.repeat(3, 1, 1)
            t = t.permute(1, 2, 0)  # HWC
        elif t.shape[-1] == 1:  # HWC gray
            t = t.repeat(1, 1, 3)
        arr = t.numpy()
    else:
        arr = np.array(im)
        if arr.ndim == 2:
            arr = arr[..., None]
        if arr.ndim == 3 and arr.shape[-1] == 1:  # HxWx1 gray
            arr = np.repeat(arr, 3, axis=-1)

    # Only float/bool data can plausibly be on a 0-1 scale; applying the
    # heuristic to integer pixels would blow up near-black uint8 images.
    may_be_unit_scale = (
        arr.dtype == np.bool_ or np.issubdtype(arr.dtype, np.floating)
    )
    arr = arr.astype(np.float32)
    # arr.size guard: max() of an empty array raises an unrelated error.
    if may_be_unit_scale and arr.size > 0 and arr.max() <= 1.5:
        arr = arr * 255.0

    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError("Expected RGB image with shape HxWx3.")
    return arr
def _letterbox_keep_ratio(arr: np.ndarray, target_hw: Tuple[int, int]):
"""Resize with aspect ratio preserved and pad with 0 (black) to target (H,W).
Returns: out(H,W,3), (top, left, new_h, new_w)
"""
th, tw = target_hw
h, w = arr.shape[:2]
scale = min(th / h, tw / w)
nh, nw = int(round(h * scale)), int(round(w * scale))
if nh <= 0 or nw <= 0:
raise ValueError("Invalid resize result.")
pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
pil = pil.resize((nw, nh), resample=Image.BILINEAR)
rs = np.array(pil, dtype=np.float32)
out = np.zeros((th, tw, 3), dtype=np.float32)
top = (th - nh) // 2
left = (tw - nw) // 2
out[top:top+nh, left:left+nw] = rs
return out, (top, left, nh, nw)
def _zscore_ignore_black(chw: np.ndarray, eps: float = 1e-8) -> np.ndarray:
mask = (chw.sum(axis=0) > 0) # HxW
if not mask.any():
return chw.copy()
valid = chw[:, mask]
mean = valid.mean()
std = valid.std()
return (chw - mean) / std if std > eps else (chw - mean)
class FilmUnet2DImageProcessor(ImageProcessingMixin):
    """
    Image processor for FILMUnet2D.

    Pre-processing (``__call__``), applied to every image:
      - Convert to RGB (float32, 0-255 scale).
      - Fit to ``size`` (default 512x512):
          * ``do_resize`` and ``keep_ratio``: aspect-preserving resize plus
            black padding (letterbox);
          * ``do_resize`` only: plain stretch to ``size``;
          * neither: paste centred on a black canvas, cropping to the
            top-left ``size`` region when the image is larger.
      - Normalize with ``image_mean`` / ``image_std`` given on the 0-255 scale.
      - Optional z-score ('self_norm') via ``_zscore_ignore_black``.

    ``__call__`` returns a dict with:
      - pixel_values:     float array/tensor [B,3,H,W]
      - original_sizes:   [B,2] (H,W) of each input image
      - letterbox_params: [B,4] (top, left, nh, nw), the image region inside
                          the ``size`` canvas

    NOTE(review): 'self_norm' runs after mean/std normalization, where zero
    padding is no longer 0, so the "sum > 0" mask in ``_zscore_ignore_black``
    does not isolate padded pixels. Left as is because it may mirror the
    training pipeline; confirm before changing.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Tuple[int, int] = (512, 512),
        keep_ratio: bool = True,
        image_mean: Tuple[float, float, float] = (123.675, 116.28, 103.53),
        image_std: Tuple[float, float, float] = (58.395, 57.12, 57.375),
        self_norm: bool = False,
        **kwargs,
    ):
        """Store the pre-processing configuration.

        Args:
            do_resize: resize inputs to ``size``; when False they are padded
                or cropped to ``size`` instead.
            size: target (H, W).
            keep_ratio: letterbox rather than stretch when resizing.
            image_mean: per-channel mean on the 0-255 scale.
            image_std: per-channel std on the 0-255 scale.
            self_norm: additionally apply ``_zscore_ignore_black``.
            **kwargs: forwarded to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.do_resize = bool(do_resize)
        self.size = tuple(size)
        self.keep_ratio = bool(keep_ratio)
        self.image_mean = tuple(float(x) for x in image_mean)
        self.image_std = tuple(float(x) for x in image_std)
        self.self_norm = bool(self_norm)

    def _fit_to_size(self, arr: np.ndarray):
        """Bring one HxWx3 image to ``self.size``.

        Returns ``(image, (top, left, nh, nw))`` where the tuple locates the
        image content inside the ``self.size`` canvas.
        """
        th, tw = self.size
        if self.do_resize:
            if self.keep_ratio:
                return _letterbox_keep_ratio(arr, self.size)
            pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
            stretched = np.array(
                pil.resize((tw, th), resample=Image.BILINEAR),  # PIL takes (W, H)
                dtype=np.float32,
            )
            return stretched, (0, 0, th, tw)

        # No resize: centre-paste on a black canvas. An image larger than the
        # canvas is cropped to its top-left (th, tw) region.
        h, w = arr.shape[:2]
        top = max((th - h) // 2, 0)
        left = max((tw - w) // 2, 0)
        # Clamp so the reported region never extends past the canvas.
        nh = min(h, th - top)
        nw = min(w, tw - left)
        canvas = np.zeros((th, tw, 3), dtype=np.float32)
        canvas[top:top + nh, left:left + nw] = arr[:nh, :nw]
        return canvas, (top, left, nh, nw)

    def __call__(
        self,
        images: Union[ArrayLike, List[ArrayLike]],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ):
        """Pre-process one image or a list/tuple of images.

        Args:
            images: a single image or a list/tuple of images.
            return_tensors: ``"pt"`` returns torch tensors; any other value
                returns a NumPy array for ``pixel_values`` and plain lists of
                tuples for the two metadata entries.
            **kwargs: accepted for call compatibility and ignored.

        Returns:
            dict with ``pixel_values`` [B,3,H,W], ``original_sizes`` [B,2] and
            ``letterbox_params`` [B,4].

        Raises:
            ValueError: if ``images`` is an empty list/tuple, or an image
                cannot be converted to HxWx3.
        """
        imgs = images if isinstance(images, (list, tuple)) else [images]
        if not imgs:
            raise ValueError("At least one image is required.")

        # Loop-invariant normalization constants, broadcast over HWC.
        mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3)
        std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3)

        batch = []
        orig_sizes = []
        lb_params = []
        for im in imgs:
            arr = _to_rgb_numpy(im)  # HWC float32 in 0-255
            oh, ow = arr.shape[:2]
            orig_sizes.append((oh, ow))

            arr, meta = self._fit_to_size(arr)  # meta=(top,left,nh,nw)
            lb_params.append(meta)

            chw = np.transpose((arr - mean) / std, (2, 0, 1))  # C,H,W
            if self.self_norm:
                chw = _zscore_ignore_black(chw)
            batch.append(chw)

        pixel_values = np.stack(batch, axis=0)  # B,C,H,W
        if return_tensors == "pt":
            pixel_values = torch.from_numpy(pixel_values).to(torch.float32)
            original_sizes = torch.tensor(orig_sizes, dtype=torch.long)
            letterbox_params = torch.tensor(lb_params, dtype=torch.long)
        else:
            original_sizes = orig_sizes
            letterbox_params = lb_params

        return {
            "pixel_values": pixel_values,
            "original_sizes": original_sizes,  # (B,2) H,W
            "letterbox_params": letterbox_params,  # (B,4) top,left,nh,nw
        }

    # ---------- POST-PROCESSING ----------
    def post_process_semantic_segmentation(
        self,
        outputs: dict,
        processor_inputs: Optional[dict] = None,
        threshold: float = 0.5,
        return_as_pil: bool = True,
    ):
        """
        Turn model outputs into binary masks at the ORIGINAL image sizes,
        with letterbox padding removed.

        Args:
            outputs: mapping from the model forward pass; ``outputs['logits']``
                is expected to be [B,1,H,W] (only channel 0 is used).
            processor_inputs: the dict returned by ``__call__``; must contain
                'original_sizes' [B,2] and 'letterbox_params' [B,4], either as
                tensors or as plain lists of tuples.
            threshold: sigmoid-probability threshold for binarization.
            return_as_pil: return PIL images (uint8, 0/255) if True, else
                torch uint8 tensors [H,W] on the logits' device.

        Returns:
            List of B masks, each at its original (H,W).

        Raises:
            ValueError: if ``processor_inputs`` is None.
        """
        # Fail before doing any tensor work.
        if processor_inputs is None:
            raise ValueError("processor_inputs must be provided to undo letterboxing.")

        logits = outputs["logits"]  # [B,1,H,W]
        probs = torch.sigmoid(logits)
        masks = (probs > threshold).to(torch.uint8) * 255  # [B,1,H,W] uint8

        # as_tensor accepts both the tensors produced with return_tensors="pt"
        # and the plain lists produced otherwise.
        orig_sizes = torch.as_tensor(
            processor_inputs["original_sizes"], dtype=torch.long
        ).tolist()  # B x [H, W]
        lb_params = torch.as_tensor(
            processor_inputs["letterbox_params"], dtype=torch.long
        ).tolist()  # B x [top, left, nh, nw]

        results = []
        for i, m in enumerate(masks[:, 0]):  # m: [H,W]
            top, left, nh, nw = (int(v) for v in lb_params[i])
            oh, ow = (int(v) for v in orig_sizes[i])

            # Drop the padding, then scale back to the original size.
            m_cropped = m[top:top + nh, left:left + nw]  # [nh,nw]
            m_resized = torch.nn.functional.interpolate(
                m_cropped[None, None].float(),
                size=(oh, ow),
                mode="nearest",  # keeps the mask strictly 0/255
            )[0, 0].to(torch.uint8)  # [oh,ow]

            if return_as_pil:
                # A 2-D uint8 array is read as mode "L"; no explicit mode needed.
                results.append(Image.fromarray(m_resized.cpu().numpy()))
            else:
                results.append(m_resized)
        return results