Update README.md and *.py files for RSEdit-UNet-text-ablation
Browse files
README.md
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RSEdit-UNet Text Encoder Ablation Models - Inference Guide
|
| 2 |
+
|
| 3 |
+
Quick guide for running inference with RSEdit UNet ablation models (text encoder variants).
|
| 4 |
+
|
| 5 |
+
## Quick Start
|
| 6 |
+
|
| 7 |
+
### Python Code Example
|
| 8 |
+
|
| 9 |
+
```python
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
|
| 13 |
+
|
| 14 |
+
# Example: DGTRS-CLIP-ViT-L-14 ablation model
|
| 15 |
+
# Each checkpoint directory is self-contained with all components
|
| 16 |
+
checkpoint_path = "/data/models/ours/BiliSakura/RSEdit-UNet-text-ablation/DGTRS-CLIP-ViT-L-14"
|
| 17 |
+
|
| 18 |
+
# Load pipeline from checkpoint (loads all components: vae, text_encoder, tokenizer, scheduler)
|
| 19 |
+
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
| 20 |
+
checkpoint_path,
|
| 21 |
+
torch_dtype=torch.bfloat16,
|
| 22 |
+
safety_checker=None,
|
| 23 |
+
requires_safety_checker=False,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Override UNet with trained EMA weights
|
| 27 |
+
pipe.unet = UNet2DConditionModel.from_pretrained(
|
| 28 |
+
f"{checkpoint_path}/checkpoint-30000/unet_ema",
|
| 29 |
+
torch_dtype=torch.bfloat16,
|
| 30 |
+
)
|
| 31 |
+
pipe = pipe.to("cuda")
|
| 32 |
+
|
| 33 |
+
# Load source image
|
| 34 |
+
source_image = Image.open("satellite_image.png").convert("RGB")
|
| 35 |
+
|
| 36 |
+
# Edit with instruction
|
| 37 |
+
prompt = "Flood the coastal area"
|
| 38 |
+
edited_image = pipe(
|
| 39 |
+
prompt=prompt,
|
| 40 |
+
image=source_image,
|
| 41 |
+
num_inference_steps=50,
|
| 42 |
+
guidance_scale=7.5,
|
| 43 |
+
image_guidance_scale=1.5,
|
| 44 |
+
).images[0]
|
| 45 |
+
|
| 46 |
+
# Save result
|
| 47 |
+
edited_image.save("edited_image.png")
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Model Structure
|
| 51 |
+
|
| 52 |
+
Each ablation model directory is self-contained and includes:
|
| 53 |
+
- `text_encoder/`: Text encoder component
|
| 54 |
+
- `tokenizer/`: Tokenizer component
|
| 55 |
+
- `vae/`: VAE component
|
| 56 |
+
- `scheduler/`: Scheduler component
|
| 57 |
+
- `unet/`: Base UNet (not used for inference)
|
| 58 |
+
- `checkpoint-30000/unet_ema/`: Trained UNet EMA weights (use for inference)
|
| 59 |
+
- `model_index.json`: Pipeline configuration
|