# BoardGenerator_4_Faces / inference.py - uploaded by CarolineM5 ("Upload inference.py", commit 5705e3e, verified)
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import PIL
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
import numpy as np
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
from PIL import Image
import random
from contextlib import nullcontext
import cv2
def pil_from(x):
    """Coerce *x* to a PIL image.

    A ``str`` is treated as a filesystem path and opened with
    ``PIL.Image.open``; any other value is handed back unchanged.
    """
    if not isinstance(x, str):
        return x
    return PIL.Image.open(x)
def _first_channel(src):
    """Return the first channel of *src* as a 2-D uint8 array.

    *src* may be a PIL image, an array-like, or a path string (resolved via
    ``pil_from``). A 2-D (single-channel, e.g. mode "L") input is returned
    as-is instead of being indexed, which previously raised IndexError.
    """
    img = pil_from(src)
    try:
        arr = np.array(img).astype(np.uint8)
    finally:
        # Close only images opened here from a path; caller-owned images stay open.
        if img is not src:
            img.close()
    return arr if arr.ndim == 2 else arr[:, :, 0]


def _split_quadrants(canvas, tile):
    """Crop a (2*tile) x (2*tile) image into tiles ordered [TL, TR, BL, BR]."""
    return [
        canvas.crop((left, top, left + tile, top + tile))
        for top in (0, tile)
        for left in (0, tile)
    ]


def inference(pipe, fiber_imgs, ring_imgs, num_steps, seed=None):
    """
    Run ``pipe`` on a fiber/ring image pair and split its 2x2 output canvas.

    fiber_imgs: PIL.Image, array-like, or path string. Only its first channel
        is used; single-channel images are accepted.
    ring_imgs: same kinds as ``fiber_imgs``; must have the same height/width.
    num_steps: int, passed to the pipeline as ``num_inference_steps``.
    seed: optional int. When None (default) a random 32-bit seed is drawn per
        call; pass a value to make a run reproducible.

    returns: list of 4 PIL.Image tiles, each 512x512, ordered
        [top-left, top-right, bottom-left, bottom-right]. Tiles keep the mode
        produced by ``pipe``; no conversion to "L" happens here (use
        ``.convert("L")`` if grayscale is needed).
        NOTE(review): an earlier docstring labelled this order as faces
        [1, 4, 3, 2] and the mode as "L" -- confirm against callers.
    """
    tile = 512
    canvas_size = tile * 2

    if seed is None:
        seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)  # also seeds torch's global RNG
    generator = torch.Generator("cpu").manual_seed(seed)

    # Channels [fiber, ring, ring] -> H x W x 3 uint8 -> RGB PIL image.
    fiber = _first_channel(fiber_imgs)
    ring = _first_channel(ring_imgs)
    input_image = PIL.Image.fromarray(np.stack([fiber, ring, ring], axis=2))

    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        # NOTE(review): CPU autocast defaults to bfloat16 -- confirm intended.
        autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")

    # `safety_checker` is a pipeline component set at construction
    # (e.g. from_pretrained(..., safety_checker=None)), not a __call__
    # argument, so it is not passed here.
    with autocast_ctx:
        out = pipe(
            prompt="",  # empty prompt: the model ignores the text prompt
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            num_images_per_prompt=1,
        )

    pred = out.images[0]  # one image per prompt was requested
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)
    return _split_quadrants(pred, tile)