| import base64 |
| import io |
| import json |
| from pathlib import Path |
| from typing import Dict, Optional |
|
|
| import cv2 |
| import psutil |
| from PIL import Image |
| from loguru import logger |
| from rich.console import Console |
| from rich.progress import ( |
| Progress, |
| SpinnerColumn, |
| TimeElapsedColumn, |
| MofNCompleteColumn, |
| TextColumn, |
| BarColumn, |
| TaskProgressColumn, |
| ) |
|
|
| from iopaint.helper import pil_to_bytes_single |
| from iopaint.model.utils import torch_gc |
| from iopaint.model_manager import ModelManager |
| from iopaint.schema import InpaintRequest |
| import numpy as np |
|
|
|
|
def glob_images(
    path: Path, *, suffixes: tuple = (".png", ".jpg", ".jpeg")
) -> Dict[str, Path]:
    """Map image file stems to their paths.

    Args:
        path: Either a single file, or a directory whose direct children
            (non-recursive) are scanned for images.
        suffixes: File extensions, with leading dot, accepted when scanning
            a directory. Matching is case-insensitive.

    Returns:
        ``{path.stem: path}`` when ``path`` is a file (no extension filtering
        is applied in that case). For a directory, ``{stem: file}`` for every
        regular file whose extension is in ``suffixes``; if two files share a
        stem, the one yielded later by ``glob`` wins. ``{}`` when ``path`` is
        neither a file nor a directory (e.g. it does not exist).
    """
    if path.is_file():
        return {path.stem: path}
    if not path.is_dir():
        # Keep the declared Dict return type so callers can len()/iterate it.
        return {}
    allowed = {suffix.lower() for suffix in suffixes}
    return {
        entry.stem: entry
        for entry in path.glob("*.*")
        # is_file() skips directories that merely look like images ("x.png/").
        if entry.is_file() and entry.suffix.lower() in allowed
    }
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def _decode_base64_image(encoded: str, what: str, grayscale: bool) -> np.ndarray:
    """Decode base64 text (optionally a data URL) into an OpenCV image array.

    Args:
        encoded: Base64 text, optionally prefixed with a data-URL header such
            as ``data:image/png;base64,``.
        what: Human-readable name of the payload, used in error messages.
        grayscale: Decode with ``cv2.IMREAD_GRAYSCALE`` when True, otherwise
            with ``cv2.IMREAD_COLOR`` (3-channel, BGR channel order).

    Returns:
        The decoded image array.

    Raises:
        ValueError: If the payload is empty or OpenCV cannot decode it.
    """
    if encoded.startswith("data:"):
        # Data URLs look like "data:<mime>;base64,<payload>"; keep only the payload.
        _, _, encoded = encoded.partition(",")
    raw = base64.b64decode(encoded)
    if not raw:
        # An empty buffer is never a valid image; depending on the OpenCV
        # version imdecode may raise cv2.error on it, so fail clearly here.
        raise ValueError(f"{what} image data is empty")
    flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imdecode(np.frombuffer(raw, np.uint8), flags)
    if image is None:
        # cv2.imdecode reports an unsupported/corrupt buffer by returning None.
        raise ValueError(f"{what} image data could not be decoded")
    return image


def batch_inpaint(
    model: str,
    device,
    input_base64: str,
    mask_base64: str,
    config_base64: Optional[str] = None,
    concat: bool = False,
):
    """Inpaint one base64-encoded image and return the result as base64 PNG.

    Args:
        model: Model name passed to ``ModelManager``.
        device: Device passed through to ``ModelManager``.
        input_base64: Base64-encoded image bytes in any format OpenCV can
            decode (a ``data:...;base64,`` prefix is accepted).
        mask_base64: Base64-encoded mask image. It is decoded as grayscale,
            resized to the input's size with nearest-neighbour if needed, and
            binarised so pixels >= 127 become 255 and the rest 0.
        config_base64: Optional base64-encoded JSON object whose keys are
            passed as keyword arguments to ``InpaintRequest``. When None, a
            default ``InpaintRequest()`` is used.
        concat: If True, return a side-by-side image: input | mask | result.

    Returns:
        The PNG-encoded result as a base64 ``str``.

    Raises:
        ValueError: If the input or mask payload is empty or cannot be decoded.
    """
    # Decode and validate both payloads first, so bad input fails before the
    # model is constructed below.
    input_image = _decode_base64_image(input_base64, "input", grayscale=False)
    mask_image = _decode_base64_image(mask_base64, "mask", grayscale=True)

    if config_base64 is None:
        inpaint_request = InpaintRequest()
    else:
        config_json = base64.b64decode(config_base64)
        inpaint_request = InpaintRequest(**json.loads(config_json))

    # NOTE(review): a new ModelManager is constructed on every call.
    model_manager = ModelManager(name=model, device=device)

    height, width = input_image.shape[:2]
    if mask_image.shape[:2] != (height, width):
        # Nearest-neighbour keeps the mask hard-edged; cv2.resize takes (w, h).
        mask_image = cv2.resize(
            mask_image, (width, height), interpolation=cv2.INTER_NEAREST
        )

    # Binarise the mask to {0, 255}.
    mask_image[mask_image >= 127] = 255
    mask_image[mask_image < 127] = 0

    inpaint_result = model_manager(input_image, mask_image, inpaint_request)

    if concat:
        # cv2.imdecode produced BGR, but Image.fromarray below interprets
        # 3-channel arrays as RGB, so convert the input panel to keep its
        # colours correct. NOTE(review): the result panel is used exactly as
        # returned by model_manager, assuming its channel order already suits
        # PIL (as the non-concat output relies on) — confirm against ModelManager.
        input_panel = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
        mask_panel = cv2.cvtColor(mask_image, cv2.COLOR_GRAY2RGB)
        inpaint_result = cv2.hconcat([input_panel, mask_panel, inpaint_result])

    with io.BytesIO() as output_buffer:
        Image.fromarray(inpaint_result).save(output_buffer, format="PNG")
        return base64.b64encode(output_buffer.getvalue()).decode("utf-8")
|
|
|
|
def batch_inpaint_cv2(
    model: str,
    device,
    input_base: np.ndarray,
    mask_base: np.ndarray,
    config_base64: Optional[str] = None,
    concat: bool = False,
):
    """Inpaint an in-memory image array and return the resulting array.

    Args:
        model: Model name passed to ``ModelManager``.
        device: Device passed through to ``ModelManager``.
        input_base: Image array handed to the model as-is. No channel-order
            conversion is done here; the caller is responsible for it.
        mask_base: Mask array. It is resized to the input's height/width with
            nearest-neighbour if needed and binarised to 0/255 on a copy, so
            the caller's array is never modified. It must be single-channel
            when ``concat`` is True (required by the GRAY2RGB conversion).
        config_base64: Optional base64-encoded JSON object whose keys are
            passed as keyword arguments to ``InpaintRequest``. When None, a
            default ``InpaintRequest()`` is used.
        concat: If True, return input | mask | result concatenated horizontally.

    Returns:
        The array returned by the model manager, or the concatenated panel
        when ``concat`` is True.

    Raises:
        ValueError: If ``input_base`` or ``mask_base`` is None (for example
            the result of a failed ``cv2.imread``).
    """
    # Validate before the model is constructed so bad input fails fast.
    if input_base is None:
        raise ValueError("input image is None (was it loaded successfully?)")
    if mask_base is None:
        raise ValueError("mask image is None (was it loaded successfully?)")

    if config_base64 is None:
        inpaint_request = InpaintRequest()
    else:
        config_json = base64.b64decode(config_base64)
        inpaint_request = InpaintRequest(**json.loads(config_json))

    # NOTE(review): a new ModelManager is constructed on every call.
    model_manager = ModelManager(name=model, device=device)

    input_image = input_base
    height, width = input_image.shape[:2]
    if mask_base.shape[:2] != (height, width):
        # cv2.resize returns a new array; it takes (width, height).
        mask_image = cv2.resize(
            mask_base, (width, height), interpolation=cv2.INTER_NEAREST
        )
    else:
        # Copy so the in-place binarisation below cannot mutate the caller's mask.
        mask_image = mask_base.copy()

    # Binarise the mask to {0, 255}.
    mask_image[mask_image >= 127] = 255
    mask_image[mask_image < 127] = 0

    inpaint_result = model_manager(input_image, mask_image, inpaint_request)

    if concat:
        mask_panel = cv2.cvtColor(mask_image, cv2.COLOR_GRAY2RGB)
        inpaint_result = cv2.hconcat([input_image, mask_panel, inpaint_result])

    return inpaint_result