# NOTE: Hugging Face Spaces status banner ("Spaces: Sleeping") stripped from the extracted source.
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import random
from contextlib import nullcontext

import cv2
import numpy as np
import PIL
import torch
import torch.nn as nn
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from PIL import Image
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
def pil_from(x):
    """Coerce *x* to a PIL.Image: open it when given a path string, else pass it through."""
    return PIL.Image.open(x) if isinstance(x, str) else x
def inference(pipe, fiber_imgs, ring_imgs, num_steps):
    """
    Run the InstructPix2Pix pipeline on a fiber/ring channel pair and split
    the 1024x1024 prediction into four 512x512 tiles.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline
        Loaded pipeline; invoked with an empty prompt (the model ignores it).
    fiber_imgs, ring_imgs : PIL.Image or str
        An image, or a path to one.  Grayscale (H, W) and multi-channel
        (H, W, C) inputs are both accepted; only channel 0 is used.
    num_steps : int
        Number of diffusion inference steps.

    Returns
    -------
    list[PIL.Image]
        The four crops [top-left, top-right, bottom-left, bottom-right] of
        the prediction, in the pipeline's output mode.
        NOTE(review): an earlier docstring claimed "L mode, order
        [1, 4, 3, 2]" — the code has always returned TL, TR, BL, BR without
        mode conversion; confirm which contract callers expect.
    """

    def _channel0(img):
        """Open *img* if it is a path; return its first channel as (H, W) uint8."""
        if isinstance(img, str):
            img = PIL.Image.open(img)
        arr = np.asarray(img).astype(np.uint8)
        # Grayscale images decode to 2-D arrays; otherwise keep channel 0 only.
        return arr if arr.ndim == 2 else arr[:, :, 0]

    # Fresh random seed, shared by torch's global RNG and the pipeline generator
    # so a single draw controls the whole run.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    tile = 512
    canvas_size = tile * 2  # the model predicts a 2x2 grid of tiles

    # Stack channels as [fiber, ring, ring] -> (H, W, 3) RGB input image.
    ch_f = _channel0(fiber_imgs)
    ch_r = _channel0(ring_imgs)
    input_image = PIL.Image.fromarray(np.stack([ch_f, ch_r, ch_r], axis=2))

    # autocast is not supported on MPS; elsewhere pick cuda when available.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")

    with autocast_ctx:
        out = pipe(
            prompt="",  # empty prompt (the model ignores the prompt)
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            num_images_per_prompt=1,
        )
        # NOTE(review): `safety_checker=None` used to be passed in this call,
        # but it is a pipeline-constructor argument, not a __call__ argument,
        # and raises TypeError on current diffusers versions.  Disable the
        # checker when loading the pipeline instead.

    # out.images is a list; a single image was requested, so take the first.
    pred = out.images[0]

    # Guarantee the expected canvas size before cropping.
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)

    # Split into the four quadrant tiles.
    tl = pred.crop((0, 0, tile, tile))
    tr = pred.crop((tile, 0, canvas_size, tile))
    bl = pred.crop((0, tile, tile, canvas_size))
    br = pred.crop((tile, tile, canvas_size, canvas_size))
    return [tl, tr, bl, br]