Create infer_style_transfer.py
Browse files- infer_style_transfer.py +73 -0
infer_style_transfer.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# ---------------------------------------------------------------------------
# User configuration: edit these values before running the script.
# ---------------------------------------------------------------------------

# LoRA checkpoints applied to the DiT, given relative to the working directory.
qwen_image_style_transfer_lora_model = './diffsynth_Qwen-Image-Edit-2509-Style-Transfer-V1.safetensors'
qwen_image_speedup_lora_model = './diffsynth_Qwen-Image-Edit-2509-Lightning-4steps-V1.0-bf16.safetensors'

content_ref = ''  # content reference image (path); must be filled in
style_ref = ''  # style reference image (path); must be filled in

# The prompt refers to the edit images by position: Figure 1 is the content
# image and Figure 2 is the style image (see the `images` list in main()).
prompt = 'Style Transfer the style of Figure 2 to Figure 1, and keep the content and characteristics of Figure 1.'

minedge = 1024  # length in pixels given to the shorter edge of the output
save_dir = './qwen_style_output/'


def fit_short_edge(width, height, short_edge=1024, multiple=16):
    """Return an output size whose shorter side equals ``short_edge``.

    The aspect ratio of ``(width, height)`` is kept; the longer side is
    truncated to an integer and then rounded down to a multiple of
    ``multiple``. Square inputs map to ``(short_edge, short_edge)`` (before
    rounding of the second component).

    Args:
        width: Source width in pixels.
        height: Source height in pixels.
        short_edge: Length assigned to the shorter side of the result.
        multiple: Granularity the longer side is rounded down to.
            NOTE(review): 16 is presumably a size constraint of the model;
            confirm against the pipeline documentation.

    Returns:
        A ``(width, height)`` tuple of ints.

    Raises:
        ValueError: If ``width`` or ``height`` is not positive.
    """
    if width <= 0 or height <= 0:
        raise ValueError(f"image size must be positive, got {width}x{height}")
    if width > height:
        long_edge = int(short_edge * (width / height))
        return long_edge - long_edge % multiple, short_edge
    long_edge = int(short_edge * (height / width))
    return short_edge, long_edge - long_edge % multiple


def result_stem(path):
    """Return the file name of ``path`` without directory or final extension.

    Only the last extension is stripped, so ``"a/monet.v2.png"`` yields
    ``"monet.v2"`` and differently named style files do not collide on the
    same output name.
    """
    return os.path.splitext(os.path.basename(path))[0]


def main():
    """Run one style-transfer inference and save the result as a PNG.

    Loads the Qwen-Image-Edit pipeline with the style-transfer and speed-up
    LoRAs, resizes the two reference images, runs the pipeline with a fixed
    seed, and writes ``<style file stem>_result.png`` into ``save_dir``.

    Raises:
        FileNotFoundError: If ``content_ref`` or ``style_ref`` does not point
            to an existing file.
    """
    # Fail before the expensive model load if the placeholders were not
    # filled in or point to missing files.
    for label, path in (("content_ref", content_ref), ("style_ref", style_ref)):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{label} is not an existing image file: {path!r}")

    pipe = QwenImagePipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
            ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
            ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ],
        tokenizer_config=None,
        processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
    )

    pipe.load_lora(pipe.dit, qwen_image_style_transfer_lora_model)
    pipe.load_lora(pipe.dit, qwen_image_speedup_lora_model)

    # Open each image once; convert() returns an in-memory copy, so the
    # underlying file handle can be closed right away.
    with Image.open(content_ref) as src:
        content_img = src.convert("RGB")
    with Image.open(style_ref) as src:
        style_img = src.convert("RGB")

    w, h = fit_short_edge(*content_img.size, short_edge=minedge)

    images = [
        content_img.resize((w, h)),
        # The style image is always squashed to a square, ignoring its aspect ratio.
        style_img.resize((minedge, minedge)),
    ]

    # 4 steps with cfg_scale=1.0 are the settings used with the Lightning
    # speed-up LoRA loaded above. Auto-resize is disabled, presumably so the
    # pipeline keeps the sizes computed here (TODO confirm).
    image = pipe(
        prompt,
        edit_image=images,
        seed=123,
        num_inference_steps=4,
        height=h,
        width=w,
        edit_image_auto_resize=False,
        cfg_scale=1.0,
    )

    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f'{result_stem(style_ref)}_result.png')
    image.save(out_path)
    print(f"saved to {out_path}")


if __name__ == "__main__":
    main()
|
| 73 |
+
|