Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| class BasePipeline(torch.nn.Module): | |
| def __init__(self, device="cuda", torch_dtype=torch.float16): | |
| super().__init__() | |
| self.device = device | |
| self.torch_dtype = torch_dtype | |
| def preprocess_image(self, image): | |
| image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) | |
| return image | |
| def preprocess_images(self, images): | |
| return [self.preprocess_image(image) for image in images] | |
| def vae_output_to_image(self, vae_output): | |
| image = vae_output[0].cpu().float().permute(1, 2, 0).numpy() | |
| image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) | |
| return image | |
| def vae_output_to_video(self, vae_output): | |
| video = vae_output.cpu().permute(1, 2, 0).numpy() | |
| video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video] | |
| return video | |
| def merge_latents(self, value, latents, masks, scales): | |
| height, width = value.shape[-2:] | |
| weight = torch.ones_like(value) | |
| for latent, mask, scale in zip(latents, masks, scales): | |
| mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0 | |
| mask = mask.repeat(1, latent.shape[1], 1, 1) | |
| value[mask] += latent[mask] * scale | |
| weight[mask] += scale | |
| value /= weight | |
| return value | |
| def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback): | |
| noise_pred_global = inference_callback(prompt_emb_global) | |
| noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] | |
| noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) | |
| return noise_pred | |