BoardGenerator / inference.py
CarolineM5's picture
Upload inference.py
8131593 verified
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import PIL
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
import numpy as np
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
from PIL import Image
import random
class UNetNoCondWrapper(nn.Module):
    """Adapter exposing a ``UNet2DModel`` through a conditional-UNet call signature.

    ``forward`` accepts the conditioning keyword arguments a caller may pass
    (``encoder_hidden_states``, ``added_cond_kwargs``,
    ``cross_attention_kwargs``) and discards them before invoking the wrapped
    model. Attribute lookups that the wrapper cannot satisfy itself are
    forwarded to the wrapped model.
    """

    def __init__(self, base_unet: UNet2DModel):
        """Store *base_unet* as the ``unet`` submodule."""
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs
    ):
        """Run the wrapped model on ``sample``/``timestep``.

        The three conditioning arguments are accepted only for signature
        compatibility and are never passed on; ``return_dict`` and any extra
        keyword arguments are.
        """
        output = self.unet(sample, timestep, return_dict=return_dict, **kwargs)
        return output

    def __getattr__(self, name):
        """Fall back to the wrapped model for attributes missing on the wrapper."""
        # These names are resolved by nn.Module's own lookup; forwarding
        # "unet" would make the `self.unet` access below call itself.
        handled_locally = ("unet", "forward", "__getstate__", "__setstate__")
        if name not in handled_locally:
            return getattr(self.unet, name)
        return super().__getattr__(name)

    def save_pretrained(self, save_directory, **kwargs):
        """Delegate saving to the wrapped ``UNet2DModel`` instance."""
        wrapped = self.unet
        return wrapped.save_pretrained(save_directory, **kwargs)
def _first_channel(img, size=(512, 512)):
    """Resize *img* to *size* and return its first channel as a 2-D array."""
    arr = np.array(img.resize(size))
    # Multi-channel images contribute only channel 0; 2-D arrays pass through.
    return arr[:, :, 0] if arr.ndim > 2 else arr


def inference(
    pipe,
    img1,
    img2,
    num_steps,
    *,
    seed=None,
    image_guidance_scale=1.9,
    guidance_scale=10,
):
    """Run *pipe* on a 3-channel image assembled from *img1* and *img2*.

    Both inputs are resized to 512x512 and reduced to their first channel.
    The *img1* channel is binarized (values above 200 become 255, all others
    0) and used as channel 0; the *img2* channel is used unchanged for
    channels 1 and 2.

    Args:
        pipe: Callable pipeline; it is invoked with an empty prompt and must
            return an object whose ``images`` attribute is a sequence of
            images supporting ``convert``.
        img1: Image-like object with ``resize`` (binarized input).
        img2: Image-like object with ``resize`` (used as-is).
        num_steps: Value passed as ``num_inference_steps``.
        seed: Optional integer seed. When ``None`` a random seed in
            ``[0, 2**32)`` is drawn, as before.
        image_guidance_scale: Passed through to *pipe*.
        guidance_scale: Passed through to *pipe*.

    Returns:
        The first generated image converted to mode ``"L"``.
    """
    if seed is None:
        seed = random.randrange(0, 2**32)
    # Seed both the global torch RNG and the generator handed to the pipeline.
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    binarized = _first_channel(img1)
    binarized = np.where(binarized > 200, 255, 0).astype(binarized.dtype)
    second = _first_channel(img2)

    # An image built from an array carries no EXIF data, so no orientation
    # correction is needed before handing it to the pipeline.
    stacked = np.stack([binarized, second, second], axis=2)
    image = PIL.Image.fromarray(stacked)

    outputs = pipe(
        prompt=[""],
        image=image,
        num_inference_steps=num_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
        # NOTE(review): kept as originally passed; this looks like a
        # construction-time option rather than a call argument - confirm
        # whether the pipeline call actually honours it.
        safety_checker=None,
        num_images_per_prompt=1,
    ).images
    return outputs[0].convert("L")