# image_processing_film_unet2d.py
"""Image processor for FILMUnet2D.

Converts inputs to RGB, letterboxes to a fixed size, normalizes in 0-255
space, and can undo the letterboxing when post-processing segmentation masks.
"""
from typing import List, Union, Tuple, Optional

import numpy as np
from PIL import Image
import torch
from transformers.image_processing_utils import ImageProcessingMixin

ArrayLike = Union[np.ndarray, torch.Tensor, Image.Image]


def _to_rgb_numpy(im: ArrayLike) -> np.ndarray:
    """Convert a PIL image / torch tensor / ndarray to float32 HWC RGB in [0, 255].

    Args:
        im: PIL Image (any mode), 3-D torch tensor (CHW or HWC, gray or RGB),
            or ndarray (2-D gray or 3-D HWC).

    Returns:
        float32 array of shape (H, W, 3) with values in 0-255.

    Raises:
        ValueError: if the input cannot be interpreted as an RGB image.
    """
    if isinstance(im, Image.Image):
        if im.mode != "RGB":
            im = im.convert("RGB")
        arr = np.array(im, dtype=np.uint8).astype(np.float32)
    elif isinstance(im, torch.Tensor):
        t = im.detach().cpu()
        if t.ndim != 3:
            raise ValueError("Tensor must be 3D (CHW or HWC).")
        # Heuristic: a leading dim of 1 or 3 is treated as channels (CHW).
        # NOTE(review): an HWC image whose height happens to be exactly 3
        # would be misdetected here — fine for typical photo-sized inputs.
        if t.shape[0] in (1, 3):  # CHW
            if t.shape[0] == 1:
                t = t.repeat(3, 1, 1)  # gray -> RGB
            t = t.permute(1, 2, 0)  # CHW -> HWC
        elif t.shape[-1] == 1:  # HWC gray
            t = t.repeat(1, 1, 3)
        arr = t.numpy()
        # Float tensors that look 0-1 normalized are rescaled to 0-255;
        # the 1.5 threshold tolerates slight overshoot from augmentation.
        if arr.dtype in (np.float32, np.float64) and arr.max() <= 1.5:
            arr = (arr * 255.0).astype(np.float32)
        else:
            arr = arr.astype(np.float32)
    else:
        arr = np.array(im)
        if arr.ndim == 2:
            arr = np.repeat(arr[..., None], 3, axis=-1)
        # Only rescale *float* arrays that look 0-1 normalized. (Previously any
        # array with max <= 1.5 was scaled, which corrupted near-black uint8
        # images whose max pixel value was 0 or 1.)
        was_float = np.issubdtype(arr.dtype, np.floating)
        arr = arr.astype(np.float32)
        if was_float and arr.max() <= 1.5:
            arr = (arr * 255.0).astype(np.float32)
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError("Expected RGB image with shape HxWx3.")
    return arr


def _letterbox_keep_ratio(arr: np.ndarray, target_hw: Tuple[int, int]):
    """Resize with aspect ratio preserved and pad with 0 (black) to target (H,W).

    Returns: out(H,W,3), (top, left, new_h, new_w)
    """
    th, tw = target_hw
    h, w = arr.shape[:2]
    scale = min(th / h, tw / w)  # fit entirely inside the target canvas
    nh, nw = int(round(h * scale)), int(round(w * scale))
    if nh <= 0 or nw <= 0:
        raise ValueError("Invalid resize result.")
    # PIL resize needs uint8; values are already 0-255 floats.
    pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
    pil = pil.resize((nw, nh), resample=Image.BILINEAR)
    rs = np.array(pil, dtype=np.float32)
    out = np.zeros((th, tw, 3), dtype=np.float32)
    top = (th - nh) // 2
    left = (tw - nw) // 2
    out[top:top + nh, left:left + nw] = rs
    return out, (top, left, nh, nw)


def _zscore_ignore_black(chw: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Z-score normalize a CHW array using statistics of non-black pixels only.

    A pixel is treated as black when its channel sum is <= 0.
    NOTE(review): in __call__ this runs *after* mean/std normalization, so
    letterbox padding is no longer exactly zero there — the sum>0 mask then
    excludes every pixel darker than the dataset mean, not just padding.
    Kept as-is to match the training pipeline; confirm against training code.
    """
    mask = chw.sum(axis=0) > 0  # HxW boolean mask of "non-black" pixels
    if not mask.any():
        return chw.copy()
    valid = chw[:, mask]
    mean = valid.mean()
    std = valid.std()
    # Guard against division by ~0 when the valid region is constant.
    return (chw - mean) / std if std > eps else (chw - mean)


class FilmUnet2DImageProcessor(ImageProcessingMixin):
    """
    Processor for FILMUnet2D:
    - Convert to RGB
    - Keep-aspect-ratio resize+pad (letterbox) to 512x512 (configurable)
    - Normalize with mean/std in 0–255 space (like your training)
    - Optional z-score 'self_norm' ignoring black pixels
    Returns dict with:
    - pixel_values: torch.FloatTensor [B,3,H,W]
    - original_sizes: torch.LongTensor [B,2] (H,W)
    - letterbox_params: torch.LongTensor [B,4] (top, left, nh, nw)
    """
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Tuple[int, int] = (512, 512),
        keep_ratio: bool = True,
        image_mean: Tuple[float, float, float] = (123.675, 116.28, 103.53),
        image_std: Tuple[float, float, float] = (58.395, 57.12, 57.375),
        self_norm: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_resize = bool(do_resize)
        self.size = tuple(size)  # (H, W)
        self.keep_ratio = bool(keep_ratio)
        self.image_mean = tuple(float(x) for x in image_mean)
        self.image_std = tuple(float(x) for x in image_std)
        self.self_norm = bool(self_norm)

    def __call__(
        self,
        images: Union[ArrayLike, List[ArrayLike]],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ):
        """Preprocess one image or a list of images.

        Returns a dict with 'pixel_values' [B,3,H,W], 'original_sizes' [B,2]
        and 'letterbox_params' [B,4] (top, left, nh, nw). Values are torch
        tensors when return_tensors == "pt", otherwise numpy / python lists.
        """
        imgs = images if isinstance(images, (list, tuple)) else [images]
        batch = []
        orig_sizes = []
        lb_params = []
        for im in imgs:
            arr = _to_rgb_numpy(im)  # HWC float32 in 0–255
            oh, ow = arr.shape[:2]
            orig_sizes.append((oh, ow))
            if self.do_resize:
                if self.keep_ratio:
                    # meta = (top, left, nh, nw) placement inside the canvas
                    arr, meta = _letterbox_keep_ratio(arr, self.size)
                else:
                    # Plain stretch to target size; entire canvas is content.
                    h, w = self.size
                    pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
                    arr = np.array(pil.resize((w, h), resample=Image.BILINEAR), dtype=np.float32)
                    meta = (0, 0, h, w)
            else:
                # No resize: center-place (or center-crop if oversized) into
                # the fixed canvas so postprocess can handle it consistently.
                h, w = arr.shape[:2]
                pad_h = self.size[0] - h
                pad_w = self.size[1] - w
                top = max(pad_h // 2, 0)
                left = max(pad_w // 2, 0)
                out = np.zeros((*self.size, 3), dtype=np.float32)
                out[top:top + h, left:left + w] = arr[:self.size[0] - top, :self.size[1] - left]
                arr = out
                # Record the dims actually stored in the canvas. (Previously
                # the full h/w were recorded even when the image was cropped,
                # which made the letterbox undo index out of bounds.)
                meta = (top, left, min(h, self.size[0] - top), min(w, self.size[1] - left))
            lb_params.append(meta)

            # Normalize in 0-255 space, matching training.
            mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3)
            std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3)
            arr = (arr - mean) / std  # HWC
            chw = np.transpose(arr, (2, 0, 1))  # C,H,W
            if self.self_norm:
                chw = _zscore_ignore_black(chw)
            batch.append(chw)

        pixel_values = np.stack(batch, axis=0)  # B,C,H,W
        if return_tensors == "pt":
            pixel_values = torch.from_numpy(pixel_values).to(torch.float32)
            original_sizes = torch.tensor(orig_sizes, dtype=torch.long)
            letterbox_params = torch.tensor(lb_params, dtype=torch.long)
        else:
            original_sizes = orig_sizes
            letterbox_params = lb_params
        return {
            "pixel_values": pixel_values,
            "original_sizes": original_sizes,    # (B,2) H,W
            "letterbox_params": letterbox_params  # (B,4) top,left,nh,nw in 512x512
        }

    # ---------- POST-PROCESSING ----------
    def post_process_semantic_segmentation(
        self,
        outputs: dict,
        processor_inputs: Optional[dict] = None,
        threshold: float = 0.5,
        return_as_pil: bool = True,
    ):
        """
        Turn model outputs into masks resized back to the ORIGINAL image sizes,
        with letterbox padding removed.

        Args:
            outputs: dict from model forward (expects 'logits': [B,1,512,512])
            processor_inputs: the dict returned by __call__ (must contain
                'original_sizes' [B,2] and 'letterbox_params' [B,4])
            threshold: probability threshold for binarization
            return_as_pil: return a list of PIL Images (uint8 0/255) if True,
                else a list of torch tensors [H,W] uint8

        Returns:
            List of masks back in original sizes (H,W).
        """
        logits = outputs["logits"]  # [B,1,H,W]
        probs = torch.sigmoid(logits)
        masks = (probs > threshold).to(torch.uint8) * 255  # [B,1,H,W] uint8
        if processor_inputs is None:
            raise ValueError("processor_inputs must be provided to undo letterboxing.")
        orig_sizes = processor_inputs["original_sizes"]    # [B,2]
        lb_params = processor_inputs["letterbox_params"]   # [B,4] top,left,nh,nw

        def _as_ints(row):
            # Accept both torch rows (return_tensors="pt") and plain tuples.
            return [int(x) for x in (row.tolist() if hasattr(row, "tolist") else row)]

        results = []
        for i in range(masks.shape[0]):
            m = masks[i, 0]  # [512,512]
            top, left, nh, nw = _as_ints(lb_params[i])
            # Crop away letterbox padding, then resize back to original size.
            m_cropped = m[top:top + nh, left:left + nw]  # [nh,nw]
            oh, ow = _as_ints(orig_sizes[i])
            m_resized = torch.nn.functional.interpolate(
                m_cropped.unsqueeze(0).unsqueeze(0).float(),
                size=(oh, ow),
                mode="nearest"
            )[0, 0].to(torch.uint8)  # [oh,ow]
            if return_as_pil:
                # 2-D uint8 array infers mode "L"; the explicit mode= argument
                # to fromarray is deprecated in modern Pillow.
                results.append(Image.fromarray(m_resized.cpu().numpy()))
            else:
                results.append(m_resized)
        return results