|
|
|
|
|
"""
|
|
|
Created on Wed Jun 11 09:51:38 2025
|
|
|
|
|
|
@author: camaac
|
|
|
"""
|
|
|
|
|
|
import PIL
|
|
|
import torch
|
|
|
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
|
|
|
import numpy as np
|
|
|
import torch.nn as nn
|
|
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
|
|
|
from PIL import Image
|
|
|
import random
|
|
|
from contextlib import nullcontext
|
|
|
import cv2
|
|
|
|
|
|
def pil_from(x):
    """Coerce *x* to a PIL.Image: open it when given a path string, otherwise pass it through unchanged."""
    return PIL.Image.open(x) if isinstance(x, str) else x
|
|
|
|
|
|
def inference(pipe, fiber_imgs, ring_imgs, num_steps):
    """
    Run the InstructPix2Pix pipeline on a fused fiber/ring image and
    return the four 512x512 quadrants of the 1024x1024 prediction.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline
        Loaded diffusers pipeline (already on the desired device).
    fiber_imgs : PIL.Image.Image or str
        Fiber image, or a filesystem path to one.
    ring_imgs : PIL.Image.Image or str
        Ring image, or a filesystem path to one.
    num_steps : int
        Number of denoising steps passed to the pipeline.

    Returns
    -------
    list[PIL.Image.Image]
        The four quadrant crops of the prediction, in raster order
        [top-left, top-right, bottom-left, bottom-right].
        NOTE(review): an earlier docstring claimed L-mode images in
        order [1, 4, 3, 2]; the code returns the pipeline's crops
        (typically RGB) in raster order — confirm against callers.
    """

    # Fresh random seed per call; seed both the global torch RNG and the
    # explicit CPU generator handed to the pipeline so this call is
    # reproducible given the seed.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    tile = 512
    canvas_size = tile * 2  # model output is treated as a 1024x1024 canvas

    # Accept either PIL images or path strings (per the documented
    # contract) via pil_from, and force 3 channels so the [:, :, 0]
    # indexing below is also safe for grayscale (L-mode) inputs.
    fiber = pil_from(fiber_imgs).convert("RGB")
    ring = pil_from(ring_imgs).convert("RGB")

    arr_f = np.array(fiber).astype(np.uint8)
    arr_r = np.array(ring).astype(np.uint8)
    # Fuse the two inputs: fiber in the R channel, ring duplicated into
    # the G and B channels.
    arr_in = np.stack([arr_f[:, :, 0], arr_r[:, :, 0], arr_r[:, :, 0]], axis=2)

    input_image = PIL.Image.fromarray(arr_in)

    # autocast is not supported on MPS; otherwise autocast on CUDA when
    # available, falling back to CPU autocast.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")

    with autocast_ctx:
        out = pipe(
            prompt="",
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            # NOTE(review): safety_checker is normally a constructor
            # argument; some diffusers versions reject it in __call__ —
            # verify against the pinned diffusers release.
            safety_checker=None,
            num_images_per_prompt=1,
        )

    pred = out.images[0]

    # Guarantee the expected canvas size before cropping quadrants.
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)

    tl = pred.crop((0, 0, tile, tile))
    tr = pred.crop((tile, 0, canvas_size, tile))
    bl = pred.crop((0, tile, tile, canvas_size))
    br = pred.crop((tile, tile, canvas_size, canvas_size))

    return [tl, tr, bl, br]
|
|
|
|
|
|
|
|
|
|