# Source: InstantRetouch BILA ZeroGPU Space ("Deploy InstantRetouch BILA ZeroGPU Space"),
# uploaded by iimmortall, commit bc275c2 (verified).
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
@dataclass(frozen=True, eq=False)
class PreparedImage:
    """One input image held at two resolutions, each as PIL image and tensor.

    As built by ``prepare_image``: ``full_*`` is the EXIF-oriented RGB image
    after ``resize_max_side``; ``model_*`` is a square resize of it used as
    model input. The tensors are the ``pil_to_tensor`` conversions of the
    corresponding PIL images.

    ``eq=False`` is deliberate. With the dataclass default ``eq=True``, the
    generated ``__eq__`` compares the tensor fields with ``==``. That
    produces an element-wise tensor whose truth value is ambiguous
    (``RuntimeError``). The generated ``__hash__`` of a frozen dataclass
    would also try to hash the PIL images, which are unhashable. Instances
    therefore compare and hash by identity.
    """

    # Working-resolution RGB image (see prepare_image).
    full_pil: Image.Image
    # Square model-input copy of the working image.
    model_pil: Image.Image
    # Tensor form of full_pil, as returned by pil_to_tensor.
    full_tensor: torch.Tensor
    # Tensor form of model_pil, as returned by pil_to_tensor.
    model_tensor: torch.Tensor
def normalize_pil(image: Image.Image) -> Image.Image:
    """Return *image* re-oriented per its EXIF Orientation tag, in RGB mode.

    ``ImageOps.exif_transpose`` applies any EXIF orientation so the pixels
    are stored upright. ``convert("RGB")`` then forces a three-channel mode
    whatever the source mode was.
    """
    upright = ImageOps.exif_transpose(image)
    rgb_image = upright.convert("RGB")
    return rgb_image
def resize_max_side(image: Image.Image, max_side: int) -> Image.Image:
    """Downscale *image* so that its longer edge is at most *max_side* pixels.

    *max_side* is first clamped to [256, 2048] and converted with ``int``.
    If the longer edge already fits, a copy of the image is returned
    unchanged. Otherwise both edges are scaled by the same factor, which
    keeps the aspect ratio. Each edge is rounded and kept at least 1 px,
    and resampling is Lanczos.
    """
    limit = int(max(256, min(max_side, 2048)))
    w, h = image.size
    long_edge = w if w >= h else h
    if long_edge > limit:
        factor = limit / float(long_edge)
        target = (max(1, round(w * factor)), max(1, round(h * factor)))
        return image.resize(target, Image.Resampling.LANCZOS)
    # Small enough already: hand back a copy so callers never alias the input.
    return image.copy()
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert an H x W x C image into a float32 tensor shaped (1, C, H, W).

    Pixel values are divided by 255, so 8-bit input lands in [0, 1]. The
    input must yield a 3-D array from ``np.asarray`` (e.g. an RGB PIL image).
    A 2-D single-channel array is not handled, because the channel axis is
    moved with ``permute(2, 0, 1)``.
    """
    pixels = np.asarray(image, dtype=np.float32)
    scaled = pixels / 255.0
    # HWC -> CHW, made contiguous so downstream ops get a dense layout.
    chw = torch.from_numpy(scaled).permute(2, 0, 1).contiguous()
    return chw.unsqueeze(dim=0)
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a [0, 1] image tensor into an 8-bit RGB PIL image.

    Accepts (C, H, W) or (N, C, H, W). For a batch, only the first item is
    converted. The tensor is detached, cast to float32, moved to CPU and
    clamped to [0, 1]. It is then scaled by 255 and rounded to uint8. A
    single-channel tensor is replicated to three channels.

    Raises:
        ValueError: if the (per-item) tensor is not 3-D with 1 or 3 channels.
    """
    tensor = tensor.detach().float().cpu().clamp(0, 1)
    if tensor.ndim == 4:
        tensor = tensor[0]  # keep only the first image of the batch
    if tensor.ndim != 3 or tensor.shape[0] not in (1, 3):
        raise ValueError(
            "expected a (C, H, W) or (N, C, H, W) tensor with 1 or 3 channels, "
            f"got shape {tuple(tensor.shape)}"
        )
    if tensor.shape[0] == 1:
        tensor = tensor.repeat(3, 1, 1)  # grayscale -> RGB
    arr = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8)
    # A uint8 (H, W, 3) array is inferred as mode "RGB". The explicit
    # ``mode=`` argument of Image.fromarray is deprecated in recent Pillow.
    return Image.fromarray(arr)
def prepare_image(image: Image.Image, max_side: int, model_size: int = 512) -> PreparedImage:
    """Build the working-resolution and model-resolution views of *image*.

    The image is EXIF-oriented and converted to RGB (``normalize_pil``). It
    is then downscaled so its longer edge is at most *max_side*
    (``resize_max_side`` clamps *max_side* itself). The model copy is a
    *model_size* x *model_size* bicubic resize of that working image, so the
    aspect ratio is not preserved. Both images are also converted with
    ``pil_to_tensor``.
    """
    working = resize_max_side(normalize_pil(image), max_side)
    square = working.resize((model_size, model_size), Image.Resampling.BICUBIC)
    working_tensor = pil_to_tensor(working)
    square_tensor = pil_to_tensor(square)
    return PreparedImage(
        full_pil=working,
        model_pil=square,
        full_tensor=working_tensor,
        model_tensor=square_tensor,
    )
def match_tensor_size(tensor: torch.Tensor, target_hw: Tuple[int, int]) -> torch.Tensor:
    """Return *tensor* with its last two (spatial) dims equal to *target_hw*.

    If the size already matches, the very same tensor object is returned.
    Otherwise it is bilinearly resampled (``align_corners=False``). Bilinear
    ``F.interpolate`` requires a 4-D (N, C, H, W) input.
    """
    current_hw = tuple(tensor.shape[-2:])
    wanted_hw = tuple(target_hw)
    if current_hw != wanted_hw:
        return F.interpolate(tensor, size=target_hw, mode="bilinear", align_corners=False)
    return tensor
def blend_strength(input_tensor: torch.Tensor, output_tensor: torch.Tensor, strength: float) -> torch.Tensor:
    """Move *input_tensor* toward *output_tensor* by *strength*; clamp to [0, 1].

    *strength* is limited to [0, 2]:
    - 0 keeps the input.
    - 1 reproduces the output (up to float rounding).
    - Values above 1 extrapolate past the output.

    *output_tensor* is first resized to the spatial size of *input_tensor*
    via ``match_tensor_size``.
    """
    weight = float(max(0.0, min(2.0, strength)))
    resized = match_tensor_size(output_tensor, input_tensor.shape[-2:])
    delta = resized - input_tensor
    blended = input_tensor + weight * delta
    return torch.clamp(blended, 0, 1)