BiliSakura commited on
Commit
b3c62ac
·
verified ·
1 Parent(s): da7b18b

Update README.md and *.py files for RSEdit-UNet-text-ablation

Browse files
Files changed (1) hide show
  1. README.md +59 -0
README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RSEdit-UNet Text Encoder Ablation Models - Inference Guide
2
+
3
+ Quick guide for running inference with RSEdit UNet ablation models (text encoder variants).
4
+
5
+ ## Quick Start
6
+
7
+ ### Python Code Example
8
+
9
+ ```python
10
+ import torch
11
+ from PIL import Image
12
+ from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
13
+
14
+ # Example: DGTRS-CLIP-ViT-L-14 ablation model
15
+ # Each checkpoint directory is self-contained with all components
16
+ checkpoint_path = "/data/models/ours/BiliSakura/RSEdit-UNet-text-ablation/DGTRS-CLIP-ViT-L-14"
17
+
18
+ # Load pipeline from checkpoint (loads all components: vae, text_encoder, tokenizer, scheduler)
19
+ pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
20
+ checkpoint_path,
21
+ torch_dtype=torch.bfloat16,
22
+ safety_checker=None,
23
+ requires_safety_checker=False,
24
+ )
25
+
26
+ # Override UNet with trained EMA weights
27
+ pipe.unet = UNet2DConditionModel.from_pretrained(
28
+ f"{checkpoint_path}/checkpoint-30000/unet_ema",
29
+ torch_dtype=torch.bfloat16,
30
+ )
31
+ pipe = pipe.to("cuda")
32
+
33
+ # Load source image
34
+ source_image = Image.open("satellite_image.png").convert("RGB")
35
+
36
+ # Edit with instruction
37
+ prompt = "Flood the coastal area"
38
+ edited_image = pipe(
39
+ prompt=prompt,
40
+ image=source_image,
41
+ num_inference_steps=50,
42
+ guidance_scale=7.5,
43
+ image_guidance_scale=1.5,
44
+ ).images[0]
45
+
46
+ # Save result
47
+ edited_image.save("edited_image.png")
48
+ ```
49
+
50
+ ## Model Structure
51
+
52
+ Each ablation model directory is self-contained and includes:
53
+ - `text_encoder/`: Text encoder component
54
+ - `tokenizer/`: Tokenizer component
55
+ - `vae/`: VAE component
56
+ - `scheduler/`: Scheduler component
57
+ - `unet/`: Base UNet (not used for inference)
58
+ - `checkpoint-30000/unet_ema/`: Trained UNet EMA weights (use for inference)
59
+ - `model_index.json`: Pipeline configuration