| import os |
| os.environ["HF_HOME"] = "/home/wanghongbo06/.cache/huggingface" |
|
|
| import torch |
| from diffusers.pipelines import FluxPipeline |
| from src.flux.condition import Condition |
| from src.flux.generate import generate, seed_everything |
| from color_fix import wavelet_color_fix, adain_color_fix |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| |
# Directory scanned for input images, and the directory that receives the
# generated results.
input_folder = "/home/wanghongbo06/diffusion-dpo-test/data_val"
output_folder = "./results-test/adv/lora_60"

# Create the output directory (and any missing parents) before inference
# starts; an already-existing directory is not an error.
os.makedirs(output_folder, exist_ok=True)
|
|
| |
# Load the FLUX.1-dev base pipeline with bfloat16 weights, then move it to
# the CUDA device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
|
|
| |
# LoRA checkpoints to merge into the pipeline, in application order:
# (weights file, adapter name).
lora_checkpoints = (
    (
        "/home/wanghongbo06/baipurui/CKPTs/FLUX_SR/pytorch_lora_weights_v2.safetensors",
        "sr",
    ),
    (
        "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors",
        "sr2",
    ),
)

for lora_path, lora_name in lora_checkpoints:
    pipe.load_lora_weights(lora_path, adapter_name=lora_name)
    # Fuse the adapter at scale 1.0 and then unload the LoRA weights, so the
    # next checkpoint is loaded onto the already-fused pipeline.
    pipe.fuse_lora(lora_scale=1.0, adapter_names=[lora_name])
    pipe.unload_lora_weights()
|
|
# Empty text prompt; the image is supplied to the pipeline via `conditions`.
prompt = ""

# File extensions (matched case-insensitively) that are treated as images.
image_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

# Filter before iterating so the progress bar counts only the files that are
# actually processed, in sorted (deterministic) order.
image_filenames = [
    name
    for name in sorted(os.listdir(input_folder))
    if name.lower().endswith(image_extensions)
]

for filename in tqdm(image_filenames):
    image_path = os.path.join(input_folder, filename)

    # Open inside a context manager so the underlying file handle is closed
    # deterministically; convert() returns an independent in-memory RGB image
    # that stays valid after the source is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Center-crop to the largest square that fits, then resize to 512x512.
    w, h = image.size
    min_dim = min(w, h)
    left = (w - min_dim) // 2
    top = (h - min_dim) // 2
    image = image.crop((left, top, left + min_dim, top + min_dim)).resize(
        (512, 512), Image.BICUBIC
    )

    # Wrap the prepared image as an "sr" condition for the generator.
    condition = Condition("sr", image)
    # Re-seed before every generation (seed_everything is called with its
    # defaults), so each image starts from the same seed state.
    seed_everything()

    result_img = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
        default_lora=True,
    ).images[0]

    # Color-correct the generated image against the conditioning input
    # (exact behavior is defined in color_fix.adain_color_fix).
    result_img = adain_color_fix(result_img, image)

    # Save under the input's file name, so the output format follows the
    # original extension.
    result_img.save(os.path.join(output_folder, filename))