# BoardGenerator / inference.py
# Uploaded by OsamaAbdeljaber — "Upload inference.py (#2)", commit 65e8581 (verified)
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import os
import random
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn as nn
import PIL
from PIL import Image
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
def pil_from(x):
    """Return a PIL.Image given a PIL.Image, a path string, or a path-like.

    Strings and ``os.PathLike`` objects are opened with ``PIL.Image.open``
    (which accepts both); anything else is returned unchanged and is assumed
    to already be a PIL.Image.
    """
    if isinstance(x, (str, os.PathLike)):
        return PIL.Image.open(x)
    return x
def inference(pipe, fiber_imgs, ring_imgs, num_steps):
    """Run the InstructPix2Pix pipeline on a 2x2 mosaic of fiber/ring tiles.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline
        Loaded pipeline; called with an empty prompt (the model ignores it).
    fiber_imgs : list/tuple of 4 PIL.Image or paths (order: TL, TR, BL, BR)
    ring_imgs : list/tuple of 4 PIL.Image or paths (same order)
    num_steps : int
        Number of denoising steps.

    Returns
    -------
    list of 4 PIL.Image tiles in order [TL, TR, BL, BR].

    Raises
    ------
    ValueError
        If fiber_imgs or ring_imgs is not a 4-element list/tuple.
    """
    # Fresh random seed each call; the same seed drives both torch's global
    # RNG and the explicit CPU generator handed to the pipeline.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)
    # sizes: four 512px tiles form a 1024px 2x2 mosaic
    tile = 512
    canvas_size = tile * 2
    # validate inputs
    if not (isinstance(fiber_imgs, (list, tuple)) and len(fiber_imgs) == 4):
        raise ValueError("fiber_imgs must be a list/tuple of 4 PIL images or file paths.")
    if not (isinstance(ring_imgs, (list, tuple)) and len(ring_imgs) == 4):
        raise ValueError("ring_imgs must be a list/tuple of 4 PIL images or file paths.")
    # Load each fiber face as a 512x512 grayscale tile.
    faces_f = [
        pil_from(f).convert("L").resize((tile, tile), PIL.Image.BILINEAR)
        for f in fiber_imgs
    ]
    # Load each ring face and binarize it: >200 -> 255, everything else -> 0.
    faces_r = []
    for rpath in ring_imgs:
        im = pil_from(rpath).convert("L").resize((tile, tile), PIL.Image.BILINEAR)
        arr = np.asarray(im)
        bin_arr = np.where(arr > 200, 255, 0).astype(np.uint8)
        faces_r.append(PIL.Image.fromarray(bin_arr))
        im.close()  # the pre-binarized intermediate was previously leaked
    # Corner offsets shared by paste and crop; order = [TL, TR, BL, BR].
    offsets = [(0, 0), (tile, 0), (0, tile), (tile, tile)]
    canvas_f = PIL.Image.new("L", (canvas_size, canvas_size))
    canvas_r = PIL.Image.new("L", (canvas_size, canvas_size))
    for face_f, face_r, off in zip(faces_f, faces_r, offsets):
        canvas_f.paste(face_f, off)
        canvas_r.paste(face_r, off)
    # Stack channels [fiber, ring, ring] -> H,W,3 so the RGB pipeline input
    # carries the fiber image in R and the ring mask in G and B.
    arr_f = np.array(canvas_f).astype(np.uint8)
    arr_r = np.array(canvas_r).astype(np.uint8)
    input_image = PIL.Image.fromarray(np.stack([arr_f, arr_r, arr_r], axis=2))
    # Autocast consistent with the device: MPS does not support autocast,
    # otherwise use "cuda" when available, else "cpu".
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast("cuda" if torch.cuda.is_available() else "cpu")
    with autocast_ctx:
        # NOTE(review): the former safety_checker=None kwarg was dropped —
        # it is a pipeline-construction argument, not a __call__ argument,
        # and current diffusers rejects it with a TypeError.
        out = pipe(
            prompt="",  # empty prompt (the model ignores the prompt)
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            num_images_per_prompt=1,
        )
    # out.images may be a list; take the first (and only) image.
    pred = out.images[0]
    # Guard against pipelines that return a different resolution.
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)
    # Split back into four tiles using the same [TL, TR, BL, BR] offsets.
    tiles = [pred.crop((x, y, x + tile, y + tile)) for (x, y) in offsets]
    # Best-effort cleanup of intermediates to free handles/memory.
    for im in faces_f + faces_r + [canvas_f, canvas_r]:
        try:
            im.close()
        except Exception:
            pass
    return tiles