# NOTE(review): removed web-scrape residue ("Spaces: Sleeping" HuggingFace page
# status text) that was accidentally pasted above the module header.
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import random
from contextlib import nullcontext

import numpy as np
import PIL
from PIL import Image
import torch
import torch.nn as nn
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionInstructPix2PixPipeline,
    UNet2DModel,
)
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
def pil_from(x):
    """Return a PIL.Image for *x*, opening it first when *x* is a path string."""
    return PIL.Image.open(x) if isinstance(x, str) else x
def inference(pipe, fiber_imgs, ring_imgs, num_steps):
    """Run the InstructPix2Pix pipeline on a 2x2 mosaic of fiber/ring tiles.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline
        Ready-to-use pipeline, already moved to the target device.
    fiber_imgs : list/tuple of 4 PIL.Image or path strings (order: TL, TR, BL, BR)
    ring_imgs : list/tuple of 4 PIL.Image or path strings (same order)
    num_steps : int
        Number of denoising (inference) steps.

    Returns
    -------
    list of 4 PIL.Image
        The four corner tiles of the edited canvas, order [TL, TR, BL, BR].

    Raises
    ------
    ValueError
        If either input collection is not a list/tuple of exactly 4 items.
    """
    # Seed both torch's global RNG and the pipeline's CPU generator from the
    # same value, so a run is reproducible if the seed is logged.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    tile = 512
    canvas_size = tile * 2  # 1024

    if not (isinstance(fiber_imgs, (list, tuple)) and len(fiber_imgs) == 4):
        raise ValueError("fiber_imgs must be a list/tuple of 4 PIL images or file paths.")
    if not (isinstance(ring_imgs, (list, tuple)) and len(ring_imgs) == 4):
        raise ValueError("ring_imgs must be a list/tuple of 4 PIL images or file paths.")

    def _prep(src, binarize):
        # Grayscale + resize one face. resize() returns a *new* image, so when
        # `src` was a path the freshly opened file handle would otherwise leak
        # (bug fix: the original code never closed it).
        opened = pil_from(src)
        im = opened.convert("L").resize((tile, tile), PIL.Image.BILINEAR)
        if opened is not src:  # we opened it from a path -> we own it
            opened.close()
        if binarize:
            # Hard threshold at 200 into a {0, 255} mask (same result as the
            # previous two-pass in-place masking, in a single expression).
            arr = np.where(np.array(im) > 200, 255, 0).astype(np.uint8)
            im = PIL.Image.fromarray(arr)
        return im

    faces_f = [_prep(p, binarize=False) for p in fiber_imgs]
    faces_r = [_prep(p, binarize=True) for p in ring_imgs]

    # Assemble the 2x2 mosaics; corner offsets in order [TL, TR, BL, BR].
    offsets = [(0, 0), (tile, 0), (0, tile), (tile, tile)]
    canvas_f = PIL.Image.new("L", (canvas_size, canvas_size))
    canvas_r = PIL.Image.new("L", (canvas_size, canvas_size))
    for face_f, face_r, off in zip(faces_f, faces_r, offsets):
        canvas_f.paste(face_f, off)
        canvas_r.paste(face_r, off)

    # Stack channels [fiber, ring, ring] -> H,W,3 and hand it over as RGB.
    arr_f = np.array(canvas_f, dtype=np.uint8)
    arr_r = np.array(canvas_r, dtype=np.uint8)
    input_image = PIL.Image.fromarray(np.stack([arr_f, arr_r, arr_r], axis=2))

    # autocast: skip on MPS (unsupported there); otherwise prefer CUDA.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")

    with autocast_ctx:
        # Bug fix: `safety_checker=None` was passed here, but safety_checker is
        # a from_pretrained()/constructor argument, not a __call__ argument —
        # current diffusers raises TypeError on unexpected call kwargs. Disable
        # the checker when building `pipe` instead.
        out = pipe(
            prompt="",  # empty prompt (this model ignores the prompt)
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            num_images_per_prompt=1,
        )

    pred = out.images[0]  # out.images is a list; we requested a single image
    # Defensive: some pipelines round sizes — force the expected canvas size.
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)

    # Split back into the 4 corner tiles, same order [TL, TR, BL, BR].
    tiles = [pred.crop((x, y, x + tile, y + tile)) for (x, y) in offsets]

    # Close intermediate images to free handles/memory.
    for im in faces_f + faces_r:
        try:
            im.close()
        except Exception:
            pass
    try:
        canvas_f.close()
        canvas_r.close()
    except Exception:
        pass

    return tiles