Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageOps | |
@dataclass
class PreparedImage:
    """Bundle of the PIL images and tensors derived from one input image.

    ``prepare_image`` builds instances with keyword arguments, so the class
    has to be a dataclass: without the decorator the annotations below
    generate no ``__init__`` and that call fails with ``TypeError``.
    """

    # Normalized image after its longest side has been capped.
    full_pil: Image.Image
    # ``full_pil`` resized to the square model input resolution.
    model_pil: Image.Image
    # Result of ``pil_to_tensor(full_pil)``.
    full_tensor: torch.Tensor
    # Result of ``pil_to_tensor(model_pil)``.
    model_tensor: torch.Tensor
def normalize_pil(image: Image.Image) -> Image.Image:
    """Apply the EXIF orientation to *image* and return it converted to RGB."""
    upright = ImageOps.exif_transpose(image)
    rgb = upright.convert("RGB")
    return rgb
def resize_max_side(image: Image.Image, max_side: int) -> Image.Image:
    """Return a copy of *image* whose longest side does not exceed *max_side*.

    *max_side* is first clamped to the range 256..2048. An image that already
    fits is returned as an unchanged copy; a larger one is scaled down with
    Lanczos resampling, both sides using the same factor and never dropping
    below one pixel.
    """
    capped = min(max_side, 2048)
    limit = int(max(256, capped))

    width, height = image.size
    longest = max(width, height)
    if longest > limit:
        scale = limit / float(longest)
        target_w = max(1, round(width * scale))
        target_h = max(1, round(height * scale))
        return image.resize((target_w, target_h), Image.Resampling.LANCZOS)

    # Already small enough: hand back a copy so callers never alias the input.
    return image.copy()
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert an H x W x C image into a float32 tensor of shape (1, C, H, W).

    Pixel values are divided by 255, so 8-bit input ends up in [0, 1].
    """
    scaled = np.asarray(image, dtype=np.float32) / 255.0
    hwc = torch.from_numpy(scaled)
    # Channels first, stored contiguously, then a leading batch axis.
    chw = hwc.permute(2, 0, 1).contiguous()
    return chw[None]
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a (C, H, W) or (N, C, H, W) tensor into an RGB PIL image.

    Values are clamped to [0, 1] before being scaled to 8-bit. For a 4-D
    input only the first item along the leading axis is converted.
    """
    pixels = tensor.detach().float().cpu().clamp(0, 1)
    if pixels.ndim == 4:
        pixels = pixels[0]
    # Channels last for PIL, then scale to 0..255 and round to whole levels.
    hwc = pixels.permute(1, 2, 0).numpy()
    arr = (hwc * 255.0).round().astype(np.uint8)
    return Image.fromarray(arr, mode="RGB")
def prepare_image(image: Image.Image, max_side: int, model_size: int = 512) -> PreparedImage:
    """Build the full-resolution and model-resolution views of *image*.

    The image is normalized with ``normalize_pil`` and size-capped with
    ``resize_max_side``; that result is also resized (bicubic) to a
    ``model_size`` x ``model_size`` square. Both images are returned together
    with their ``pil_to_tensor`` conversions.
    """
    normalized = normalize_pil(image)
    full_pil = resize_max_side(normalized, max_side)

    # The model view is a fixed square, so the aspect ratio is not preserved.
    square = (model_size, model_size)
    model_pil = full_pil.resize(square, Image.Resampling.BICUBIC)

    full_tensor = pil_to_tensor(full_pil)
    model_tensor = pil_to_tensor(model_pil)
    return PreparedImage(
        full_pil=full_pil,
        model_pil=model_pil,
        full_tensor=full_tensor,
        model_tensor=model_tensor,
    )
def match_tensor_size(tensor: torch.Tensor, target_hw: Tuple[int, int]) -> torch.Tensor:
    """Return *tensor* with its last two dimensions equal to *target_hw*.

    When the size already matches, the very same tensor object is returned;
    otherwise it is resampled with bilinear interpolation
    (``align_corners=False``).
    """
    current_hw = tuple(tensor.shape[-2:])
    if current_hw != tuple(target_hw):
        tensor = F.interpolate(
            tensor,
            size=target_hw,
            mode="bilinear",
            align_corners=False,
        )
    return tensor
def blend_strength(input_tensor: torch.Tensor, output_tensor: torch.Tensor, strength: float) -> torch.Tensor:
    """Blend *output_tensor* into *input_tensor* by *strength*.

    *strength* is clamped to [0, 2]: 0 keeps the input, 1 yields the output,
    and values above 1 push past the output. *output_tensor* is first resized
    to the input's spatial size, and the result is clamped to [0, 1].
    """
    weight = float(max(0.0, min(2.0, strength)))
    target_hw = input_tensor.shape[-2:]
    resized = match_tensor_size(output_tensor, target_hw)
    delta = resized - input_tensor
    blended = input_tensor + weight * delta
    return blended.clamp(0, 1)