# BoardGenerator / inference.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageOps
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionInstructPix2PixPipeline,
    UNet2DModel,
)
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer


class UNetNoCondWrapper(nn.Module):
    """Wraps an unconditional UNet2DModel so it accepts (and ignores) the
    text-conditioning arguments that StableDiffusionInstructPix2PixPipeline
    passes to its UNet."""

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs,
    ):
        # Drop the conditioning arguments: UNet2DModel only takes (sample, timestep).
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # Resolve the wrapper's own attributes through nn.Module; delegate
        # everything else (config, dtype, ...) to the wrapped UNet.
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        # Delegate to the real UNet2DModel instance.
        return self.unet.save_pretrained(save_directory, **kwargs)


def inference(model_id, device, img1, img2):
    # 1) Load each pipeline component from its subfolder.
    vae = AutoencoderKL.from_pretrained(f"{model_id}/vae").to(device)
    scheduler = DDPMScheduler.from_pretrained(f"{model_id}/scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(f"{model_id}/tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(f"{model_id}/text_encoder").to(device)
    feature_extractor = CLIPImageProcessor.from_pretrained(f"{model_id}/feature_extractor")

    # 2) Load the unconditional UNet and wrap it.
    base_unet = UNet2DModel.from_pretrained(f"{model_id}/unet").to(device)
    wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

    # 3) Assemble the pipeline manually.
    pipe = StableDiffusionInstructPix2PixPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=wrapped_unet,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=feature_extractor,
        requires_safety_checker=False,
    )
    pipe = pipe.to(device, torch.float16)
    generator = torch.Generator(device=device).manual_seed(0)

    # 4) Pre-process the two inputs into a single 3-channel conditioning image.
    img1 = img1.resize((512, 512))
    img2 = img2.resize((512, 512))
    img1_np = np.array(img1)
    if img1_np.ndim > 2:
        img1_np = img1_np[:, :, 0]
    img2_np = np.array(img2)
    if img2_np.ndim > 2:
        img2_np = img2_np[:, :, 0]
    # Binarize img1 at 200, then invert it.
    img1_np[img1_np > 200] = 255
    img1_np[img1_np <= 200] = 0
    img1_np = 255 - img1_np
    # Channel 0: inverted binary mask from img1; channels 1-2: img2.
    img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
    image = Image.fromarray(img_np)
    image = ImageOps.exif_transpose(image)

    # 5) Run the edit with an empty prompt (the wrapped UNet ignores text anyway).
    num_inference_steps = 20
    image_guidance_scale = 1.9
    guidance_scale = 10
    edited_image = pipe(
        prompt=[""],
        image=image,
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=1,
    ).images
    return edited_image
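

# Minimal usage sketch, assuming a local checkpoint directory laid out with the
# vae/, scheduler/, tokenizer/, text_encoder/, feature_extractor/ and unet/
# subfolders loaded above. The path and image filenames are placeholders.
if __name__ == "__main__":
    model_id = "path/to/BoardGenerator"  # hypothetical checkpoint directory
    device = "cuda" if torch.cuda.is_available() else "cpu"
    img1 = Image.open("mask.png")   # hypothetical binarizable mask image
    img2 = Image.open("board.png")  # hypothetical grayscale board image
    images = inference(model_id, device, img1, img2)
    images[0].save("edited.png")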