Spaces:
Sleeping
Sleeping
File size: 2,661 Bytes
3c8903a 007755c 3c8903a 0600e9e 3c8903a 9299372 3c8903a 9299372 3c8903a 0600e9e 3c8903a 9299372 0600e9e fe68fb3 0600e9e 3c8903a 0600e9e 4ddc0c8 9299372 3c8903a 5705e3e 3c8903a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import PIL
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
import numpy as np
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
from PIL import Image
import random
from contextlib import nullcontext
import cv2
def pil_from(x):
    """Coerce *x* to a PIL.Image: open it when given a path string, else pass it through unchanged."""
    return PIL.Image.open(x) if isinstance(x, str) else x
def inference(pipe, fiber_imgs, ring_imgs, num_steps):
    """
    Run the InstructPix2Pix pipeline on a fiber/ring image pair and split the
    1024x1024 output into four 512x512 tiles.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline
        Loaded diffusers pipeline.
    fiber_imgs : PIL.Image.Image or str
        Fiber image, or a filesystem path to one; only channel 0 is used.
    ring_imgs : PIL.Image.Image or str
        Ring image, or a filesystem path to one; only channel 0 is used.
    num_steps : int
        Number of diffusion inference steps.

    Returns
    -------
    list[PIL.Image.Image]
        Four crops of the model output in order
        [top-left, top-right, bottom-left, bottom-right].
        NOTE(review): an earlier docstring claimed "L mode, order [1, 4, 3, 2]";
        the code neither converts to L nor reorders — confirm downstream
        expectations.
    """
    # Fresh random seed per call; the explicit CPU generator keeps the
    # diffusion sampling reproducible for this one run.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    # The model emits a 2x2 grid of 512px tiles on a 1024px canvas.
    tile = 512
    canvas_size = tile * 2

    # BUG FIX: accept path strings as documented by routing through pil_from()
    # (previously a path crashed at the numpy conversion below).
    fiber_imgs = pil_from(fiber_imgs)
    ring_imgs = pil_from(ring_imgs)

    arr_f = np.array(fiber_imgs).astype(np.uint8)
    arr_r = np.array(ring_imgs).astype(np.uint8)
    # BUG FIX: tolerate single-channel (mode "L") inputs whose arrays are 2-D;
    # the channel-0 indexing below otherwise raises IndexError.
    if arr_f.ndim == 2:
        arr_f = arr_f[:, :, None]
    if arr_r.ndim == 2:
        arr_r = arr_r[:, :, None]

    # Stack channels [fiber, ring, ring] -> H,W,3 RGB conditioning image.
    arr_in = np.stack([arr_f[:, :, 0], arr_r[:, :, 0], arr_r[:, :, 0]], axis=2)
    input_image = PIL.Image.fromarray(arr_in)

    # MPS does not play well with torch.autocast here, so use a no-op context;
    # otherwise autocast on CUDA if available, else CPU.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")

    with autocast_ctx:
        out = pipe(
            prompt="",  # empty prompt (the model ignores the text prompt)
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            num_images_per_prompt=1,
        )
        # BUG FIX: dropped safety_checker=None from this call — it is a
        # from_pretrained()/constructor argument, not a __call__ kwarg, and
        # diffusers raises TypeError on the unexpected keyword. Disable the
        # checker where the pipeline is constructed instead.

    # out.images is a list; we requested a single image.
    pred = out.images[0]
    # Normalize to the expected canvas size before tiling.
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)

    # Split into four tiles: top-left, top-right, bottom-left, bottom-right.
    tl = pred.crop((0, 0, tile, tile))
    tr = pred.crop((tile, 0, canvas_size, tile))
    bl = pred.crop((0, tile, tile, canvas_size))
    br = pred.crop((tile, tile, canvas_size, canvas_size))
    return [tl, tr, bl, br]
|