File size: 2,661 Bytes
3c8903a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
007755c
3c8903a
 
 
 
 
 
0600e9e
 
3c8903a
9299372
 
3c8903a
 
9299372
3c8903a
 
0600e9e
 
 
3c8903a
 
 
9299372
0600e9e
 
 
 
 
fe68fb3
 
0600e9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c8903a
 
0600e9e
 
 
 
4ddc0c8
9299372
3c8903a
 
5705e3e
 
3c8903a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# -*- coding: utf-8 -*-
"""

Created on Wed Jun 11 09:51:38 2025



@author: camaac

"""

import PIL
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler 
import numpy as np
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
from PIL import Image
import random
from contextlib import nullcontext
import cv2

def pil_from(x):
    """Coerce *x* into a PIL.Image.

    A path string is opened with PIL; anything else (assumed to already be
    a PIL.Image) is passed through unchanged.
    """
    if not isinstance(x, str):
        return x
    return PIL.Image.open(x)

def _first_channel(arr):
    """Return the first channel of an image array as a 2-D uint8 H,W array.

    Handles both 3-D (H,W,C) arrays from RGB images and 2-D (H,W) arrays
    from single-channel (L-mode) images — the original code crashed on the
    latter because it unconditionally indexed ``arr[:, :, 0]``.
    """
    if arr.ndim == 2:
        return arr
    return arr[:, :, 0]


def inference(pipe, fiber_imgs, ring_imgs, num_steps):
    """Run the InstructPix2Pix pipeline on a fiber/ring channel-stacked image.

    Parameters
    ----------
    pipe : callable
        A diffusers StableDiffusionInstructPix2PixPipeline (or compatible
        callable returning an object with an ``.images`` list).
    fiber_imgs : PIL.Image or str
        Fiber image, or a path to one (opened via ``pil_from``).
    ring_imgs : PIL.Image or str
        Ring image, or a path to one.
    num_steps : int
        Number of inference steps for the pipeline.

    Returns
    -------
    list of 4 PIL.Image
        The 1024x1024 prediction split into 512px tiles, in order
        [top-left, top-right, bottom-left, bottom-right].
    """
    # Accept paths as documented: the original never called pil_from(),
    # so a path string crashed at np.array(...) channel indexing.
    fiber_img = pil_from(fiber_imgs)
    ring_img = pil_from(ring_imgs)

    # Fresh random seed each call -> outputs are intentionally non-deterministic.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    # The model emits a 2x2 grid of 512px tiles on a 1024px canvas.
    tile = 512
    canvas_size = tile * 2

    # Stack channels [fiber, ring, ring] -> H,W,3 uint8 for a PIL RGB image.
    arr_f = _first_channel(np.array(fiber_img).astype(np.uint8))
    arr_r = _first_channel(np.array(ring_img).astype(np.uint8))
    arr_in = np.stack([arr_f, arr_r, arr_r], axis=2)

    input_image = PIL.Image.fromarray(arr_in)  # PIL RGB

    # Autocast consistent with the available device; MPS does not support
    # torch.autocast here, so fall back to a no-op context.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(torch.device("cuda").type if torch.cuda.is_available() else "cpu")

    with autocast_ctx:
        out = pipe(
            prompt="",  # empty prompt (the model ignores the text prompt)
            image=input_image,
            num_inference_steps=num_steps,
            image_guidance_scale=1.9,
            guidance_scale=10.0,
            generator=generator,
            # NOTE(review): safety_checker is normally a pipeline-constructor
            # argument; many diffusers versions reject it in __call__ —
            # verify against the installed diffusers version.
            safety_checker=None,
            num_images_per_prompt=1,
        )
        # out.images is a list; this function produces a single image.
        pred = out.images[0]

    # Ensure the prediction matches the expected canvas before tiling.
    if pred.size != (canvas_size, canvas_size):
        pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)

    # Split into 4 tiles: top-left, top-right, bottom-left, bottom-right.
    tl = pred.crop((0, 0, tile, tile))
    tr = pred.crop((tile, 0, canvas_size, tile))
    bl = pred.crop((0, tile, tile, canvas_size))
    br = pred.crop((tile, tile, canvas_size, canvas_size))

    return [tl, tr, bl, br]