| | |
| | |
| | import random |
| | from collections import OrderedDict |
| |
|
| | import torch, numpy as np |
| | from PIL import Image |
| | from scepter.modules.model.registry import MODELS |
| | from scepter.modules.utils.config import Config |
| | from scepter.modules.utils.distribute import we |
| | from .registry import BaseInference, INFERENCES |
| | from .utils import ACEPlusImageProcessor |
| |
|
@INFERENCES.register_class()
class ACEInference(BaseInference):
    '''
    Inference wrapper that reuses the ldm code.

    Builds the model described by ``cfg.MODEL``, preprocesses the
    reference/edit inputs with ``ACEPlusImageProcessor``, runs the model
    under autocast and converts the results back to PIL images.
    '''
    def __init__(self, cfg, logger=None):
        '''
        Build the model and the image processor from ``cfg``.

        Args:
            cfg: Config providing ``MODEL``, ``MAX_SEQ_LEN``, ``SAMPLE_ARGS``
                and optionally ``DTYPE`` (name of a torch dtype, defaults to
                "bfloat16").
            logger: Optional logger forwarded to the base class.
        '''
        super().__init__(cfg, logger)
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        # Lower-case every SAMPLE_ARGS key; dict-like entries are collapsed to
        # their 'DEFAULT' value (None when absent), other values are kept as-is.
        self.input = {
            k.lower(): dict(v).get('DEFAULT', None)
            if isinstance(v, (dict, OrderedDict, Config)) else v
            for k, v in cfg.SAMPLE_ARGS.items()
        }
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @staticmethod
    def _signed_tensor_to_pil(tensor):
        '''
        Convert an image tensor to a PIL image.

        The tensor is mapped with ``x / 2 + 0.5``, clamped to [0, 1] and
        scaled to 8-bit.
        NOTE(review): presumably the input is channel-first (C, H, W) in
        [-1, 1]; confirm against ACEPlusImageProcessor.preprocess.
        '''
        scaled = torch.clamp(tensor / 2 + 0.5, min=0.0, max=1.0) * 255
        return Image.fromarray(
            scaled.float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        '''
        Run one generation/editing pass.

        Args:
            reference_image, edit_image, edit_mask: Inputs handed to
                ``ACEPlusImageProcessor.preprocess``.
            prompt: A single string or a list of strings.
            output_height, output_width: Requested output size, forwarded to
                the preprocessor as ``height`` / ``width``.
            sample_steps, guide_scale: Forwarded to the model.
            seed: Sampling seed; a negative value or None draws a random
                seed in [0, 2**24 - 1].
            repainting_scale, use_change, keep_pixels, keep_pixels_rate:
                Forwarded to the preprocessor.
            sampler, edit_type, lora_path, **kwargs: Accepted for interface
                compatibility but not used by this method; the sampler
                passed to the model is always 'flow_euler'.

        Returns:
            Tuple ``(result, edit_image, change_image, mask, seed)``:
            the postprocessed first output of the model, the preprocessed
            edit image, change image and mask as PIL images, and the seed
            actually used. ``change_image`` is None when the preprocessor
            produced no change image.
        '''
        if isinstance(prompt, str):
            prompt = [prompt]
        if seed is None or seed < 0:
            seed = random.randint(0, 2 ** 24 - 1)

        # content_image is returned by the preprocessor but not needed here.
        (image_t, mask_t, change_t, _content_image,
         out_h, out_w, slice_w) = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)

        if change_t is not None:
            change_t = change_t.to(we.device_id)
        image_t = image_t.to(we.device_id)
        mask_t = mask_t.to(we.device_id)

        # The model takes one list per sample, hence the extra nesting.
        images, masks, changes = [image_t], [mask_t], [change_t]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=[images],
                modify_image_list=[changes],
                src_mask_list=[masks],
                edit_id=[[0]],
                image=images,
                image_mask=masks,
                prompt=[prompt],
                sampler='flow_euler',
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )

        imgs = [
            Image.fromarray(
                (x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy() * 255)
                .astype(np.uint8))
            for x_i in out_image
        ]

        edit_image = self._signed_tensor_to_pil(image_t)
        # The preprocessor may legitimately return no change image; the
        # original code divided None by 2 here and raised TypeError.
        change_image = None if change_t is None else self._signed_tensor_to_pil(change_t)
        mask = Image.fromarray((mask_t * 255).squeeze(0).cpu().numpy().astype(np.uint8))

        result = self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h)
        return result, edit_image, change_image, mask, seed
| |
|