File size: 2,544 Bytes
76bee53
 
 
 
 
 
 
 
 
 
 
 
 
a9653d6
b3652cf
76bee53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00edf85
b3652cf
 
 
990a91c
b3652cf
76bee53
 
 
 
 
 
 
 
 
 
 
 
 
 
00edf85
76bee53
 
 
 
00edf85
 
170b294
51f3d5f
76bee53
 
 
 
 
 
 
 
 
 
 
 
 
 
8131593
990a91c
7402c4e
990a91c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
"""

Created on Wed Jun 11 09:51:38 2025



@author: camaac

"""

import random

import numpy as np
import PIL
import torch
import torch.nn as nn
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from PIL import Image, ImageOps
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor

class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional ``UNet2DModel`` stand in where a
    text-conditioned UNet is expected.

    Diffusers pipelines call the UNet with conditioning arguments
    (``encoder_hidden_states``, ``added_cond_kwargs``, ...); this wrapper
    accepts and silently discards them, forwarding only ``sample`` and
    ``timestep`` to the wrapped model.
    """

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        # Assigned through nn.Module.__setattr__, so the base UNet is
        # registered as a submodule (its parameters are tracked normally).
        self.unet = base_unet

    def forward(

        self,

        sample,

        timestep,

        encoder_hidden_states=None,    # ignored: base UNet is unconditional

        added_cond_kwargs=None,        # ignored

        cross_attention_kwargs=None,   # ignored

        return_dict=False,

        **kwargs                       # passed through to the base UNet

    ):
        
        # Drop every conditioning argument and call the bare UNet2DModel.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails. The guard list
        # routes the wrapper's own attributes through nn.Module.__getattr__
        # (which resolves registered submodules such as "unet") and prevents
        # infinite recursion; everything else is delegated to the wrapped
        # UNet so the wrapper is a drop-in replacement (config, dtype, ...).
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)
    
    def save_pretrained(self, save_directory, **kwargs):
        # Delegate to the real UNet2DModel instance so the checkpoint is
        # written in the standard diffusers format (not the wrapper's).
        return self.unet.save_pretrained(save_directory, **kwargs)
    
def inference(pipe, img1, img2, num_steps):
    """Run one InstructPix2Pix-style edit on a pair of input images.

    Parameters
    ----------
    pipe : StableDiffusionInstructPix2PixPipeline (or compatible callable)
        Pipeline invoked with a prompt/image pair; must return an object
        with an ``.images`` list of PIL images.
    img1 : PIL.Image
        Mask-like image; binarized at threshold 200 and used as channel 0.
    img2 : PIL.Image
        Content image; duplicated into channels 1 and 2.
    num_steps : int
        Number of diffusion inference steps.

    Returns
    -------
    PIL.Image
        The edited image converted to single-channel ("L") mode.
    """
    # Fresh random seed per call so repeated calls sample different results.
    seed = random.randrange(0, 2**32)
    torch.manual_seed(seed)  # keep torch's global RNG in sync with the generator
    generator = torch.Generator("cpu").manual_seed(seed)

    img1 = img1.resize((512, 512))
    img2 = img2.resize((512, 512))

    # Collapse any multi-channel input down to its first channel.
    img1_np = np.array(img1)
    if img1_np.ndim > 2:
        img1_np = img1_np[:, :, 0]

    img2_np = np.array(img2)
    if img2_np.ndim > 2:
        img2_np = img2_np[:, :, 0]

    # Binarize the mask: values above 200 -> 255, everything else -> 0.
    img1_np = np.where(img1_np > 200, 255, 0).astype(np.uint8)

    # Pack mask + content into a 3-channel image the pipeline can consume.
    img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
    image = Image.fromarray(img_np)
    # BUG FIX: `PIL.ImageOps` is never loaded by `import PIL` alone, so the
    # original `PIL.ImageOps.exif_transpose(...)` raised AttributeError.
    # (An array-built image carries no EXIF, so this is a no-op kept for
    # parity with the original intent.)
    image = ImageOps.exif_transpose(image)

    edited_images = pipe(
        prompt=[""],
        image=image,
        num_inference_steps=num_steps,
        image_guidance_scale=1.9,  # how strongly to follow the input image
        guidance_scale=10,         # text guidance (prompt is empty here)
        generator=generator,
        # NOTE(review): `safety_checker` is a pipeline *constructor* argument;
        # passing it to __call__ may TypeError on some diffusers versions —
        # verify against the installed version before removing.
        safety_checker=None,
        num_images_per_prompt=1,
    ).images

    return edited_images[0].convert("L")