# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025

@author: camaac
"""
import random

import numpy as np
import PIL
import torch
import torch.nn as nn
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionInstructPix2PixPipeline,
    UNet2DModel,
)
from PIL import Image, ImageOps  # ImageOps was missing: PIL.ImageOps is used below
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer


class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional ``UNet2DModel`` stand in for a
    conditional UNet.

    Conditioning kwargs (``encoder_hidden_states``, ``added_cond_kwargs``,
    ``cross_attention_kwargs``) are accepted for API compatibility with
    conditional pipelines and silently discarded.
    """

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,   # accepted but ignored (unconditional UNet)
        added_cond_kwargs=None,       # ignored
        cross_attention_kwargs=None,  # ignored
        return_dict=False,
        **kwargs,
    ):
        # Forward only the arguments the unconditional UNet understands.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # nn.Module keeps submodules in _modules, so attribute access already
        # routes through __getattr__. Resolve our own attributes normally and
        # delegate everything else to the wrapped UNet so the wrapper is
        # transparent to callers (config, dtype, device, etc.).
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        # Delegate to the real UNet2DModel instance.
        return self.unet.save_pretrained(save_directory, **kwargs)


def inference(pipe, img1, img2, num_steps):
    """Run InstructPix2Pix-style editing on a pair of images.

    Both inputs are resized to 512x512 and reduced to a single channel
    (first channel if multi-channel). ``img1`` is binarized at threshold 200
    (values > 200 -> 255, else 0), then the two maps are stacked into a
    3-channel conditioning image [img1, img2, img2] fed to the pipeline.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline
        A loaded diffusers pipeline.
    img1, img2 : PIL.Image.Image
        Input images; assumed to be grayscale-like (only channel 0 is used).
    num_steps : int
        Number of denoising inference steps.

    Returns
    -------
    PIL.Image.Image
        The edited image converted to single-channel ("L") mode.
    """
    # Fresh random seed per call; mirror it into both the global torch RNG
    # and the explicit CPU generator handed to the pipeline.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    img1 = img1.resize((512, 512))
    img2 = img2.resize((512, 512))

    # Keep only the first channel of multi-channel inputs.
    img1_np = np.array(img1)
    if len(img1_np.shape) > 2:
        img1_np = img1_np[:, :, 0]
    img2_np = np.array(img2)
    if len(img2_np.shape) > 2:
        img2_np = img2_np[:, :, 0]

    # Binarize the first map (threshold chosen empirically at 200).
    img1_np[img1_np > 200] = 255
    img1_np[img1_np <= 200] = 0
    # img1_np = 255 - img1_np

    # Conditioning image: channel 0 = binarized img1, channels 1-2 = img2.
    img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
    image = PIL.Image.fromarray(img_np)
    image = ImageOps.exif_transpose(image)

    num_inference_steps = num_steps
    image_guidance_scale = 1.9
    guidance_scale = 10

    # NOTE(review): safety_checker is not a documented __call__ kwarg of
    # StableDiffusionInstructPix2PixPipeline — confirm the installed diffusers
    # version accepts it (it is normally a from_pretrained argument).
    edited_image = pipe(
        prompt=[""],
        image=image,
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
        safety_checker=None,
        num_images_per_prompt=1,
    ).images

    # BUG FIX: the original ended with a bare `return` followed by a dangling
    # `edited_image` expression, so the function returned None.
    edited_image = edited_image[0].convert("L")
    return edited_image