# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025

@author: camaac
"""
import PIL
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
import numpy as np
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
from PIL import Image
import random
from contextlib import nullcontext
import cv2  # NOTE(review): unused in this file — confirm no external caller relies on it before removing


def pil_from(x):
    """Return a PIL.Image given either a PIL.Image or a path string."""
    if isinstance(x, str):
        return PIL.Image.open(x)
    return x


def inference(pipe, fiber_imgs, ring_imgs, num_steps):
    """Run the InstructPix2Pix pipeline on a fiber/ring image pair.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline
        Loaded pipeline; invoked with an empty prompt (the model is
        assumed to ignore the text prompt).
    fiber_imgs : PIL.Image or str
        Fiber image, or a filesystem path to one.
    ring_imgs : PIL.Image or str
        Ring image, or a filesystem path to one.
    num_steps : int
        Number of diffusion inference steps.

    Returns
    -------
    list[PIL.Image]
        Four 512x512 tiles cropped from the predicted 1024x1024 canvas,
        in order [top-left, top-right, bottom-left, bottom-right].
    """
    # Seed both torch's global RNG and the generator handed to the pipe
    # so a given run is reproducible from the sampled seed.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    tile = 512
    canvas_size = tile * 2  # model output is a 2x2 grid of 512px tiles

    # BUG FIX: the original called np.array() on the raw arguments and never
    # used pil_from(), so passing a path string (which the docstring allows)
    # produced a 0-d string array and crashed on channel indexing below.
    # convert("RGB") additionally guarantees a 3-channel array so that
    # [:, :, 0] is always valid (e.g. for L-mode inputs).
    fiber_img = pil_from(fiber_imgs).convert("RGB")
    ring_img = pil_from(ring_imgs).convert("RGB")

    # Build the conditioning image: channels [fiber, ring, ring] -> (H, W, 3).
    arr_f = np.array(fiber_img).astype(np.uint8)
    arr_r = np.array(ring_img).astype(np.uint8)
    arr_in = np.stack([arr_f[:, :, 0], arr_r[:, :, 0], arr_r[:, :, 0]], axis=2)
    input_image = PIL.Image.fromarray(arr_in)  # PIL RGB

    # Autocast on CUDA (or CPU); MPS does not support this autocast path.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")

    with autocast_ctx:
        # BUG FIX: `safety_checker=None` was removed — it is a
        # from_pretrained()/constructor argument, not a __call__ argument,
        # and is either silently ignored or raises TypeError depending on
        # the diffusers version.
        out = pipe(
            prompt="",  # empty prompt (your model ignores prompt)
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            num_images_per_prompt=1,
        )

    # out.images is a list; take the first (and only) sample.
    pred = out.images[0]

    # Ensure the prediction matches the expected canvas before tiling.
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)

    # Split into 4 tiles: top-left, top-right, bottom-left, bottom-right.
    tl = pred.crop((0, 0, tile, tile))
    tr = pred.crop((tile, 0, canvas_size, tile))
    bl = pred.crop((0, tile, tile, canvas_size))
    br = pred.crop((tile, tile, canvas_size, canvas_size))

    return [tl, tr, bl, br]