# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025
@author: camaac
"""
import PIL.Image
import PIL.ImageOps  # needed for PIL.ImageOps.exif_transpose below
import numpy as np
import torch
import torch.nn as nn
from diffusers import (
    StableDiffusionInstructPix2PixPipeline,
    UNet2DModel,
    AutoencoderKL,
    DDPMScheduler,
)
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional UNet2DModel stand in for the
    text-conditioned UNet expected by StableDiffusionInstructPix2PixPipeline:
    the conditioning kwargs passed by the pipeline are silently discarded."""

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs,
    ):
        # Drop the conditioning arguments; UNet2DModel does not accept them.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # Forward attribute lookups (config, dtype, ...) to the wrapped UNet,
        # except for the wrapper's own attributes.
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        # Delegate to the real UNet2DModel instance.
        return self.unet.save_pretrained(save_directory, **kwargs)
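
# A minimal smoke test for the wrapper (a sketch, not part of the original
# script): it builds a tiny, randomly initialized UNet2DModel and checks that
# conditioning kwargs are discarded rather than crashing the forward pass.
# The tiny config below is an arbitrary assumption chosen only to keep the
# test small and fast.
def _smoke_test_wrapper():
    tiny_unet = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    wrapper = UNetNoCondWrapper(tiny_unet)
    # encoder_hidden_states would be rejected by a bare UNet2DModel;
    # the wrapper silently drops it.
    noise_pred = wrapper(
        torch.randn(1, 3, 32, 32),  # latent-like sample
        torch.tensor(10),           # timestep
        encoder_hidden_states=torch.randn(1, 77, 768),
    )[0]
    assert noise_pred.shape == (1, 3, 32, 32)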
def inference(model_id, device, img1, img2):
    # 1) Load the individual pipeline components.
    vae = AutoencoderKL.from_pretrained(f"{model_id}/vae").to(device)
    scheduler = DDPMScheduler.from_pretrained(f"{model_id}/scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(f"{model_id}/tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(f"{model_id}/text_encoder").to(device)
    feature_extractor = CLIPImageProcessor.from_pretrained(f"{model_id}/feature_extractor")
    # 2) Load the unconditional UNet and wrap it.
    base_unet = UNet2DModel.from_pretrained(f"{model_id}/unet").to(device)
    wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
    # 3) Assemble the pipeline manually. requires_safety_checker=False avoids
    # the diffusers warning triggered by safety_checker=None.
    pipe = StableDiffusionInstructPix2PixPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=wrapped_unet,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=feature_extractor,
        requires_safety_checker=False,
    )
    # Cast to half precision and move to the target device.
    pipe = pipe.to(torch.float16).to(device)
    # Seed the RNG on the requested device rather than hard-coding "cuda".
    generator = torch.Generator(device).manual_seed(0)
    img1 = img1.resize((512, 512))
    img2 = img2.resize((512, 512))
    # Keep a single channel per image.
    img1_np = np.array(img1)
    if len(img1_np.shape) > 2:
        img1_np = img1_np[:, :, 0]
    img2_np = np.array(img2)
    if len(img2_np.shape) > 2:
        img2_np = img2_np[:, :, 0]
    # Binarize the first image at threshold 200, then invert it.
    img1_np[img1_np > 200] = 255
    img1_np[img1_np <= 200] = 0
    img1_np = 255 - img1_np
    # Stack the inverted mask with two copies of the second image as RGB channels.
    img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
    image = PIL.Image.fromarray(img_np)
    image = PIL.ImageOps.exif_transpose(image)
    num_inference_steps = 20
    image_guidance_scale = 1.9
    guidance_scale = 10
    # The safety checker is already disabled in the constructor; it is not a
    # valid argument of the pipeline call itself.
    edited_images = pipe(
        prompt=[""],
        image=image,
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=1,
    ).images  # .images is a list of PIL images
    return edited_images
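
# Example invocation: a minimal sketch, not part of the original script.
# The checkpoint directory and image file names are hypothetical placeholders,
# and a CUDA device is assumed because the pipeline is cast to float16.
if __name__ == "__main__":
    device = "cuda"
    img1 = PIL.Image.open("mask.png")    # hypothetical first input (binarized as a mask)
    img2 = PIL.Image.open("image.png")   # hypothetical second input
    outputs = inference("./trained_pipeline", device, img1, img2)
    outputs[0].save("edited.png")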