Spaces: Running on Zero
Upload 61 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/discriminator.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/ema.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/inflated_layers.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/inflated_lib.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/ocr.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/pick_score_training.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/prompts.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/rewards.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/stat_tracking.cpython-310.pyc +0 -0
- adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc +0 -0
- adv_grpo/aesthetic_scorer.py +53 -0
- adv_grpo/assets/activities.txt +3 -0
- adv_grpo/assets/activities_v0.txt +3 -0
- adv_grpo/assets/flow_grpo_fast.png +3 -0
- adv_grpo/assets/imagenet_classes.txt +1000 -0
- adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth +3 -0
- adv_grpo/assets/simple_animals.txt +45 -0
- adv_grpo/assets/simple_ocr_animals.txt +5 -0
- adv_grpo/assets/simple_ocr_animals_digit1.txt +45 -0
- adv_grpo/assets/simple_ocr_animals_digit3.txt +45 -0
- adv_grpo/assets/simple_ocr_animals_digit5.txt +50 -0
- adv_grpo/assets/test.jpg +0 -0
- adv_grpo/clip_scorer.py +97 -0
- adv_grpo/conv_gradfix.py +345 -0
- adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc +0 -0
- adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc +0 -0
- adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc +0 -0
- adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py +255 -0
- adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py +187 -0
- adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py +198 -0
- adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py +1081 -0
- adv_grpo/diffusers_patch/sd3_sde_with_logprob.py +139 -0
- adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py +144 -0
- adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py +144 -0
- adv_grpo/diffusers_patch/wan_pipeline_with_logprob.py +373 -0
- adv_grpo/diffusers_patch/wan_prompt_embedding.py +97 -0
- adv_grpo/ema.py +88 -0
- adv_grpo/imagereward_scorer.py +40 -0
- adv_grpo/inflated_layers.py +305 -0
- adv_grpo/inflated_lib.py +346 -0
- adv_grpo/ocr.py +138 -0
- adv_grpo/pick_score_training.py +385 -0
- adv_grpo/pickscore_scorer.py +70 -0
- adv_grpo/pickscore_scorer_constractive.py +89 -0
- adv_grpo/pickscore_scorer_patch.py +78 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+adv_grpo/assets/flow_grpo_fast.png filter=lfs diff=lfs merge=lfs -text

adv_grpo/__pycache__/conv_gradfix.cpython-310.pyc ADDED (binary file, 9.28 kB)
adv_grpo/__pycache__/discriminator.cpython-310.pyc ADDED (binary file, 2.95 kB)
adv_grpo/__pycache__/ema.cpython-310.pyc ADDED (binary file, 3.3 kB)
adv_grpo/__pycache__/grpo_discriminator.cpython-310.pyc ADDED (binary file, 625 Bytes)
adv_grpo/__pycache__/inflated_layers.cpython-310.pyc ADDED (binary file, 7.45 kB)
adv_grpo/__pycache__/inflated_lib.cpython-310.pyc ADDED (binary file, 10.1 kB)
adv_grpo/__pycache__/ocr.cpython-310.pyc ADDED (binary file, 4.41 kB)
adv_grpo/__pycache__/patchgan_discriminator.cpython-310.pyc ADDED (binary file, 5.55 kB)
adv_grpo/__pycache__/pick_score_training.cpython-310.pyc ADDED (binary file, 9.12 kB)
adv_grpo/__pycache__/pickscore_scorer.cpython-310.pyc ADDED (binary file, 2.44 kB)
adv_grpo/__pycache__/prompts.cpython-310.pyc ADDED (binary file, 2.93 kB)
adv_grpo/__pycache__/rewards.cpython-310.pyc ADDED (binary file, 29.7 kB)
adv_grpo/__pycache__/stat_tracking.cpython-310.pyc ADDED (binary file, 2.62 kB)
adv_grpo/__pycache__/stylegan_discriminator.cpython-310.pyc ADDED (binary file, 14.1 kB)

adv_grpo/aesthetic_scorer.py ADDED
@@ -0,0 +1,53 @@
+# Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py
+
+from importlib import resources
+import torch
+import torch.nn as nn
+import numpy as np
+from transformers import CLIPModel, CLIPProcessor
+from PIL import Image
+
+ASSETS_PATH = resources.files("adv_grpo.assets")
+
+
+class MLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(768, 1024),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            nn.Dropout(0.1),
+            nn.Linear(64, 16),
+            nn.Linear(16, 1),
+        )
+
+    @torch.no_grad()
+    def forward(self, embed):
+        return self.layers(embed)
+
+
+class AestheticScorer(torch.nn.Module):
+    def __init__(self, dtype):
+        super().__init__()
+        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        self.mlp = MLP()
+        state_dict = torch.load(
+            ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")
+        )
+        self.mlp.load_state_dict(state_dict)
+        self.dtype = dtype
+        self.eval()
+
+    @torch.no_grad()
+    def __call__(self, images):
+        device = next(self.parameters()).device
+        inputs = self.processor(images=images, return_tensors="pt")
+        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
+        embed = self.clip.get_image_features(**inputs)
+        # normalize embedding
+        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
+        return self.mlp(embed).squeeze(1)
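
A minimal usage sketch for the scorer above (hypothetical, not part of this commit; the float32 dtype and CUDA device are assumptions, and test.jpg is simply the sample image shipped in this upload):

    import torch
    from PIL import Image
    from adv_grpo.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).to("cuda")  # dtype/device are assumptions
    images = [Image.open("adv_grpo/assets/test.jpg")]         # sample image from this upload
    scores = scorer(images)                                   # 1-D tensor, one aesthetic score per image
    print(scores)
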
adv_grpo/assets/activities.txt ADDED
@@ -0,0 +1,3 @@
+washing the dishes
+riding a bike
+playing chess

adv_grpo/assets/activities_v0.txt ADDED
@@ -0,0 +1,3 @@
+washing the dishes
+riding a bike
+playing chess

adv_grpo/assets/flow_grpo_fast.png ADDED
Git LFS Details

adv_grpo/assets/imagenet_classes.txt ADDED
@@ -0,0 +1,1000 @@
+tench, Tinca tinca
+goldfish, Carassius auratus
+great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
+tiger shark, Galeocerdo cuvieri
+hammerhead, hammerhead shark
[... 993 further lines: the standard ImageNet-1k class names, one per line ...]
+ear, spike, capitulum
+toilet tissue, toilet paper, bathroom tissue

adv_grpo/assets/sac+logos+ava1-l14-linearMSE.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21dd590f3ccdc646f0d53120778b296013b096a035a2718c9cb0d511bff0f1e0
+size 3714759

adv_grpo/assets/simple_animals.txt ADDED
@@ -0,0 +1,45 @@
+cat
+dog
+horse
+monkey
+rabbit
+zebra
+spider
+bird
+sheep
+deer
+cow
+goat
+lion
+tiger
+bear
+raccoon
+fox
+wolf
+lizard
+beetle
+ant
+butterfly
+fish
+shark
+whale
+dolphin
+squirrel
+mouse
+rat
+snake
+turtle
+frog
+chicken
+duck
+goose
+bee
+pig
+turkey
+fly
+llama
+camel
+bat
+gorilla
+hedgehog
+kangaroo

adv_grpo/assets/simple_ocr_animals.txt ADDED
@@ -0,0 +1,5 @@
+cat
+dog
+horse
+monkey
+rabbit

adv_grpo/assets/simple_ocr_animals_digit1.txt ADDED
@@ -0,0 +1,45 @@
+A cat holding a sign that says '0'
+A dog holding a sign that says '0'
+A horse holding a sign that says '0'
+A monkey holding a sign that says '0'
+A rabbit holding a sign that says '0'
+A cat holding a sign that says '1'
+A dog holding a sign that says '1'
+A horse holding a sign that says '1'
+A monkey holding a sign that says '1'
+A rabbit holding a sign that says '1'
+A cat holding a sign that says '2'
+A dog holding a sign that says '2'
+A horse holding a sign that says '2'
+A monkey holding a sign that says '2'
+A rabbit holding a sign that says '2'
+A cat holding a sign that says '3'
+A dog holding a sign that says '3'
+A horse holding a sign that says '3'
+A monkey holding a sign that says '3'
+A rabbit holding a sign that says '3'
+A cat holding a sign that says '4'
+A dog holding a sign that says '4'
+A horse holding a sign that says '4'
+A monkey holding a sign that says '4'
+A rabbit holding a sign that says '4'
+A cat holding a sign that says '5'
+A dog holding a sign that says '5'
+A horse holding a sign that says '5'
+A monkey holding a sign that says '5'
+A rabbit holding a sign that says '5'
+A cat holding a sign that says '6'
+A dog holding a sign that says '6'
+A horse holding a sign that says '6'
+A monkey holding a sign that says '6'
+A rabbit holding a sign that says '6'
+A cat holding a sign that says '7'
+A dog holding a sign that says '7'
+A horse holding a sign that says '7'
+A monkey holding a sign that says '7'
+A rabbit holding a sign that says '7'
+A cat holding a sign that says '8'
+A dog holding a sign that says '8'
+A horse holding a sign that says '8'
+A monkey holding a sign that says '8'
+A rabbit holding a sign that says '8'

adv_grpo/assets/simple_ocr_animals_digit3.txt ADDED
@@ -0,0 +1,45 @@
+A cat holding a sign that says '123'
+A dog holding a sign that says '234'
+A horse holding a sign that says '345'
+A monkey holding a sign that says '456'
+A rabbit holding a sign that says '567'
+A cat holding a sign that says '678'
+A dog holding a sign that says '789'
+A horse holding a sign that says '123'
+A monkey holding a sign that says '234'
+A rabbit holding a sign that says '345'
+A cat holding a sign that says '456'
+A dog holding a sign that says '567'
+A horse holding a sign that says '678'
+A monkey holding a sign that says '789'
+A rabbit holding a sign that says '123'
+A cat holding a sign that says '234'
+A dog holding a sign that says '345'
+A horse holding a sign that says '456'
+A monkey holding a sign that says '567'
+A rabbit holding a sign that says '678'
+A cat holding a sign that says '789'
+A dog holding a sign that says '123'
+A horse holding a sign that says '234'
+A monkey holding a sign that says '345'
+A rabbit holding a sign that says '456'
+A cat holding a sign that says '567'
+A dog holding a sign that says '678'
+A horse holding a sign that says '789'
+A monkey holding a sign that says '123'
+A rabbit holding a sign that says '234'
+A cat holding a sign that says '345'
+A dog holding a sign that says '456'
+A horse holding a sign that says '567'
+A monkey holding a sign that says '678'
+A rabbit holding a sign that says '789'
+A cat holding a sign that says '123'
+A dog holding a sign that says '234'
+A horse holding a sign that says '345'
+A monkey holding a sign that says '456'
+A rabbit holding a sign that says '567'
+A cat holding a sign that says '678'
+A dog holding a sign that says '789'
+A horse holding a sign that says '123'
+A monkey holding a sign that says '234'
+A rabbit holding a sign that says '345'

adv_grpo/assets/simple_ocr_animals_digit5.txt ADDED
@@ -0,0 +1,50 @@
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'
+A cat holding a sign that says '12345'
+A dog holding a sign that says '23456'
+A horse holding a sign that says '34567'
+A monkey holding a sign that says '45678'
+A rabbit holding a sign that says '56789'
+A cat holding a sign that says '54321'
+A dog holding a sign that says '65432'
+A horse holding a sign that says '76543'
+A monkey holding a sign that says '87654'
+A rabbit holding a sign that says '98765'

adv_grpo/assets/test.jpg ADDED

adv_grpo/clip_scorer.py ADDED
@@ -0,0 +1,97 @@
+# Based on https://github.com/RE-N-Y/imscore/blob/main/src/imscore/preference/model.py
+
+from importlib import resources
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as T
+from transformers import AutoImageProcessor, CLIPProcessor, CLIPModel
+import numpy as np
+from PIL import Image
+
+def get_size(size):
+    if isinstance(size, int):
+        return (size, size)
+    elif "height" in size and "width" in size:
+        return (size["height"], size["width"])
+    elif "shortest_edge" in size:
+        return size["shortest_edge"]
+    else:
+        raise ValueError(f"Invalid size: {size}")
+
+def get_image_transform(processor: AutoImageProcessor):
+    config = processor.to_dict()
+    resize = T.Resize(get_size(config.get("size"))) if config.get("do_resize") else nn.Identity()
+    crop = T.CenterCrop(get_size(config.get("crop_size"))) if config.get("do_center_crop") else nn.Identity()
+    normalise = T.Normalize(mean=processor.image_mean, std=processor.image_std) if config.get("do_normalize") else nn.Identity()
+
+    return T.Compose([resize, crop, normalise])
+
+class ClipScorer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # self.device="cuda"
+        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        self.tform = get_image_transform(self.processor.image_processor)
+        self.eval()
+
+    def _process(self, pixels):
+        dtype = pixels.dtype
+        pixels = self.tform(pixels)
+        pixels = pixels.to(dtype=dtype)
+
+        return pixels
+
+    @torch.no_grad()
+    def __call__(self, pixels, prompts, return_img_embedding=False):
+        device = next(self.parameters()).device
+        texts = self.processor(text=prompts, padding='max_length', truncation=True, return_tensors="pt").to(device)
+        pixels = self._process(pixels).to(device)
+        outputs = self.model(pixel_values=pixels, **texts)
+        if return_img_embedding:
+            return outputs.logits_per_image.diagonal()/30, outputs.image_embeds
+        return outputs.logits_per_image.diagonal()/30
+
+    @torch.no_grad()
+    def image_similarity(self, pixels, ref_pixels):
+        device = next(self.parameters()).device
+        pixels = self._process(pixels).to(device)
+        ref_pixels = self._process(ref_pixels).to(device)
+
+        pixel_embeds = self.model.get_image_features(pixel_values=pixels)
+        ref_embeds = self.model.get_image_features(pixel_values=ref_pixels)
+
+        pixel_embeds = pixel_embeds / pixel_embeds.norm(p=2, dim=-1, keepdim=True)
+        ref_embeds = ref_embeds / ref_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 67 |
+
|
| 68 |
+
sim = pixel_embeds @ ref_embeds.T
|
| 69 |
+
# sim = torch.diagonal(sim, 0)
|
| 70 |
+
sim = sim.squeeze(-1)
|
| 71 |
+
return sim
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main():
|
| 75 |
+
# scorer = ClipScorer(
|
| 76 |
+
# device='cuda'
|
| 77 |
+
# )
|
| 78 |
+
scorer = ClipScorer(
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
images=[
|
| 82 |
+
"assets/test.jpg",
|
| 83 |
+
"assets/test.jpg"
|
| 84 |
+
]
|
| 85 |
+
pil_images = [Image.open(img) for img in images]
|
| 86 |
+
prompts=[
|
| 87 |
+
'an image of cat',
|
| 88 |
+
'not an image of cat'
|
| 89 |
+
]
|
| 90 |
+
images = [np.array(img) for img in pil_images]
|
| 91 |
+
images = np.array(images)
|
| 92 |
+
images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW
|
| 93 |
+
images = torch.tensor(images, dtype=torch.uint8)/255.0
|
| 94 |
+
print(scorer(images, prompts))
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
main()
|
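Beyond the text-image scoring exercised by `main()` above, the module exposes an image-to-image path. A minimal sketch of `image_similarity`, reusing the repository's `assets/test.jpg` for both inputs; the `to_tensor` helper is illustrative and not part of `clip_scorer.py`:

# Sketch only: exercises ClipScorer.image_similarity on float CHW tensors in [0, 1].
import numpy as np
import torch
from PIL import Image
from adv_grpo.clip_scorer import ClipScorer

scorer = ClipScorer()  # add .cuda() if a GPU is available

def to_tensor(path):  # hypothetical helper for this sketch
    arr = np.array(Image.open(path).convert("RGB")).transpose(2, 0, 1)  # HWC -> CHW
    return torch.tensor(arr, dtype=torch.float32).unsqueeze(0) / 255.0

gen = to_tensor("assets/test.jpg")
ref = to_tensor("assets/test.jpg")
print(scorer.image_similarity(gen, ref))  # cosine similarity of the CLIP image embeddings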
adv_grpo/conv_gradfix.py
ADDED
@@ -0,0 +1,345 @@
"""
Custom replacement for `torch.nn.functional.convNd` and `torch.nn.functional.conv_transposeNd`
that supports arbitrarily high order gradients with zero performance penalty.
Modified from https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py
"""

import contextlib
import warnings
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Conv2d, Conv3d

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

# ----------------------------------------------------------------------------

enabled = False  # Enable the custom op by setting this to true.
weight_gradients_disabled = (
    False  # Forcefully disable computation of gradients with respect to the weights.
)


@contextlib.contextmanager
def no_weight_gradients():
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old


# ----------------------------------------------------------------------------
class GradFixConv2d(Conv2d):
    def __init__(self, *args, use_gradfix: bool = False, **kwargs):
        self.use_gradfix = use_gradfix
        super().__init__(*args, **kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        conv_fn = F.conv2d if not self.use_gradfix else convNd
        if self.padding_mode != "zeros":
            return conv_fn(
                F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                weight,
                bias,
                self.stride,
                (0, 0),
                self.dilation,
                self.groups,
            )
        return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(
        self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
    ) -> Tensor:
        weight = self.weight if weight is None else weight
        bias = self.bias if bias is None else bias
        return self._conv_forward(input, weight, bias)


class GradFixConv3d(Conv3d):
    def __init__(self, *args, use_gradfix: bool = False, **kwargs):
        self.use_gradfix = use_gradfix
        super().__init__(*args, **kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        conv_fn = F.conv3d if not self.use_gradfix else convNd
        if self.padding_mode != "zeros":
            return conv_fn(
                F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                weight,
                bias,
                self.stride,
                (0, 0, 0),
                self.dilation,
                self.groups,
            )
        return conv_fn(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(
        self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
    ) -> Tensor:
        weight = self.weight if weight is None else weight
        bias = self.bias if bias is None else bias
        return self._conv_forward(input, weight, bias)


# ----------------------------------------------------------------------------


def convNd(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    N = weight.ndim - 2
    if _should_use_custom_op(input):
        return _conv_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups,
        ).apply(input, weight, bias)
    return getattr(torch.nn.functional, f"conv{N}d")(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )


def conv_transposeNd(
    input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1
):
    N = weight.ndim - 2
    if _should_use_custom_op(input):
        return _conv_gradfix(
            transpose=True,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        ).apply(input, weight, bias)
    return getattr(torch.nn.functional, f"conv_transpose{N}d")(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )


# ----------------------------------------------------------------------------


def _should_use_custom_op(input):
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != "cuda":
        return False
    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9"]):
        return True
    if torch.__version__.startswith("2"):
        return True
    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. "
        f"Falling back to torch.nn.functional.conv2d()."
    )
    return False


def _tuple_of_ints(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs


# ----------------------------------------------------------------------------

_conv_gradfix_cache = dict()


def _conv_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    ndim = len(weight_shape) - 2
    # Parse arguments.
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv_gradfix_cache:
        return _conv_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else:  # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)

    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0] * ndim
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class ConvNd(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            """
            input size: [B, C, ...]
            weight size:
                -> Conv: [C_out, C_in // groups, ...]
                -> Transpose: [C_in, C_out // groups, ...]
            """
            assert weight.shape == weight_shape
            ctx.save_for_backward(input, weight)

            # General case => cuDNN.
            if transpose:
                return getattr(torch.nn.functional, f"conv_transpose{ndim}d")(
                    input=input,
                    weight=weight.to(input.dtype),
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )
            return getattr(torch.nn.functional, f"conv{ndim}d")(
                input=input, weight=weight.to(input.dtype), bias=bias, **common_kwargs
            )

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:  # Input
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                op = _conv_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                )
                grad_input = op.apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:  # Weight
                grad_weight = ConvNdGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:  # Bias
                grad_bias = grad_output.transpose(0, 1).flatten(1).sum(1)

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class ConvNdGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            if torch.__version__.startswith("1"):
                op = torch._C._jit_get_operation(
                    "aten::cudnn_convolution_backward_weight"
                    if not transpose
                    else "aten::cudnn_convolution_transpose_backward_weight"
                )
                grad_weight = op(
                    weight_shape,
                    grad_output,
                    input.to(grad_output.dtype),
                    padding,
                    stride,
                    dilation,
                    groups,
                    *flags,
                )
            elif torch.__version__.startswith("2"):
                # https://github.com/pytorch/pytorch/issues/74437
                op, _ = torch._C._jit_get_operation("aten::convolution_backward")
                dummy_weight = torch.tensor(
                    0.0, dtype=grad_output.dtype, device=input.device
                ).expand(weight_shape)
                grad_weight = op(
                    grad_output,
                    input.to(grad_output.dtype),
                    dummy_weight,
                    None,
                    stride,
                    padding,
                    dilation,
                    transpose,
                    (0,) * ndim,
                    groups,
                    [False, True, False],
                )[1]
            else:
                raise NotImplementedError
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:  # Grad of Weight
                grad2_grad_output = ConvNd.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:  # Input
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                op = _conv_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                )
                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input

    _conv_gradfix_cache[key] = ConvNd
    return ConvNd


# ----------------------------------------------------------------------------
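A short usage sketch of the module above, assuming it is importable as `adv_grpo.conv_gradfix` and a CUDA device with cuDNN is available (the custom op only activates on CUDA; elsewhere it silently falls back to the stock functional ops):

# Sketch: input-gradient penalty (R1-style) that skips the unneeded weight gradient.
import torch
from adv_grpo import conv_gradfix

conv_gradfix.enabled = True  # opt in to the custom autograd Function

conv = conv_gradfix.GradFixConv2d(3, 8, kernel_size=3, padding=1, use_gradfix=True).cuda()
x = torch.randn(2, 3, 32, 32, device="cuda", requires_grad=True)

with conv_gradfix.no_weight_gradients():
    out = conv(x).square().sum()
    # create_graph=True exercises the higher-order gradients the module exists for
    (grad_x,) = torch.autograd.grad(out, x, create_graph=True)
penalty = grad_x.square().sum()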
adv_grpo/diffusers_patch/__pycache__/sd3_pipeline_with_logprob_fast.cpython-310.pyc
ADDED
Binary file (15.5 kB).
adv_grpo/diffusers_patch/__pycache__/sd3_sde_with_logprob.cpython-310.pyc
ADDED
Binary file (3.54 kB).
adv_grpo/diffusers_patch/__pycache__/train_dreambooth_lora_sd3.cpython-310.pyc
ADDED
Binary file (2.48 kB).
adv_grpo/diffusers_patch/flux_kontext_pipeline_with_logprob.py
ADDED
@@ -0,0 +1,255 @@
# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_kontext.py

from typing import Any, Dict, List, Optional, Union, Callable
import torch
import numpy as np
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.utils import logging
from .sd3_sde_with_logprob import sde_step_with_logprob

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

@torch.no_grad()
def pipeline_with_logprob(
    self,
    image: Optional[PipelineImageInput] = None,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt: Union[str, List[str]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    max_area: int = 1024**2,
    _auto_resize: bool = True,
    noise_level: float = 0.7,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    original_height, original_width = height, width
    aspect_ratio = width / height
    width = round((max_area * aspect_ratio) ** 0.5)
    height = round((max_area / aspect_ratio) ** 0.5)

    multiple_of = self.vae_scale_factor * 2
    width = width // multiple_of * multiple_of
    height = height // multiple_of * multiple_of

    if height != original_height or width != original_width:
        logger.warning(
            f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 3. Preprocess image
    if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
        image = self.image_processor.resize(image, height, width)
        image = self.image_processor.preprocess(image, height, width)
    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, image_latents, latent_ids, image_ids = self.prepare_latents(
        image.float(),
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    if image_ids is not None:
        latent_ids = torch.cat([latent_ids, image_ids], dim=0)  # dim 0 is sequence dimension
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("max_image_seq_len", 4096),
        self.scheduler.config.get("base_shift", 0.5),
        self.scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # handle guidance
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    # 6. Prepare image embeddings
    all_latents = [latents]
    all_log_probs = []

    # 7. Denoising loop
    self.scheduler.set_begin_index(0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue
            self._current_timestep = t
            latent_model_input = latents
            if image_latents is not None:
                latent_model_input = torch.cat([latents, image_latents], dim=1)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            if noise_pred.isnan().any():
                breakpoint()  # leftover debug hook: the transformer produced NaNs
                print("noise_pred is nan")
            noise_pred = noise_pred[:, : latents.size(1)]
            latents_dtype = latents.dtype

            latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                self.scheduler,
                noise_pred.float(),
                t.unsqueeze(0).repeat(latents.shape[0]),
                latents.float(),
                noise_level=noise_level,
            )

            if latents.dtype != latents_dtype:
                latents = latents.to(latents_dtype)
            all_latents.append(latents)
            all_log_probs.append(log_prob)
            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    latents = latents.to(dtype=self.vae.dtype)
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    return image, all_latents, latent_ids, text_ids, all_log_probs, image_latents
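Since `calculate_shift` above carries the scheduler's resolution-dependent timestep shift, a quick check of its defaults may be useful (pure arithmetic on the signature shown; the import path mirrors this upload's layout):

# mu interpolates linearly in the packed-latent sequence length:
# slope m = (1.15 - 0.5) / (4096 - 256) = 0.65 / 3840, intercept b = 0.5 - 256 * m,
# so mu(256) = 0.5 and mu(4096) = 1.15.
from adv_grpo.diffusers_patch.flux_kontext_pipeline_with_logprob import calculate_shift

assert abs(calculate_shift(256) - 0.5) < 1e-9
assert abs(calculate_shift(4096) - 1.15) < 1e-9
assert abs(calculate_shift(1024) - 0.63) < 1e-9  # 1024 * 0.65/3840 + 0.456666... = 0.63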
adv_grpo/diffusers_patch/flux_pipeline_with_logprob.py
ADDED
@@ -0,0 +1,187 @@
# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py

from typing import Any, Dict, List, Optional, Union, Callable
import torch
import numpy as np
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from .sd3_sde_with_logprob import sde_step_with_logprob

def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

@torch.no_grad()
def pipeline_with_logprob(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt: Union[str, List[str]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    noise_level: float = 0.7,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
        sigmas = None
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("max_image_seq_len", 4096),
        self.scheduler.config.get("base_shift", 0.5),
        self.scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # handle guidance
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    # 6. Prepare image embeddings
    all_latents = [latents]
    all_log_probs = []

    # 7. Denoising loop
    self.scheduler.set_begin_index(0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue
            self._current_timestep = t
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            latents_dtype = latents.dtype
            latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                self.scheduler,
                noise_pred.float(),
                t.unsqueeze(0).repeat(latents.shape[0]),
                latents.float(),
                noise_level=noise_level,
            )
            if latents.dtype != latents_dtype:
                latents = latents.to(latents_dtype)
            all_latents.append(latents)
            all_log_probs.append(log_prob)
            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    latents = latents.to(dtype=self.vae.dtype)
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    return image, all_latents, latent_image_ids, text_ids, all_log_probs
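A minimal sketch of driving the patched sampler above. Assumptions: FLUX.1-dev weights are downloadable, a CUDA device is present, and `noise_level` is simply forwarded to `sde_step_with_logprob` as shown in the file; the prompt is borrowed from the asset lists in this upload:

# Sketch only: the patched function is unbound, so the pipeline is passed as `self`.
import torch
from diffusers import FluxPipeline
from adv_grpo.diffusers_patch.flux_pipeline_with_logprob import pipeline_with_logprob

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

images, all_latents, latent_image_ids, text_ids, all_log_probs = pipeline_with_logprob(
    pipe,
    prompt=["A cat holding a sign that says '12345'"],
    num_inference_steps=10,
    guidance_scale=3.5,
    noise_level=0.7,
)
# len(all_latents) == num_inference_steps + 1 (initial noise plus one entry per step);
# len(all_log_probs) == num_inference_steps.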
adv_grpo/diffusers_patch/sd3_pipeline_with_logprob.py
ADDED
@@ -0,0 +1,198 @@
# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
# with the following modifications:
# - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`.
# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
from typing import Any, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob

@torch.no_grad()
def pipeline_with_logprob(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
    skip_layer_guidance_scale: float = 2.8,
    noise_level: float = 0.7,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._skip_layer_guidance_scale = skip_layer_guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    scheduler_kwargs = {}
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        **scheduler_kwargs,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 6. Prepare image embeddings
    all_latents = [latents]
    all_log_probs = []

    # 7. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents_dtype = latents.dtype

            latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                self.scheduler,
                noise_pred.float(),
                t.unsqueeze(0),
                latents.float(),
                noise_level=noise_level,
            )

            all_latents.append(latents)
            all_log_probs.append(log_prob)
            if latents.dtype != latents_dtype:
                latents = latents.to(latents_dtype)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    latents = latents.to(dtype=self.vae.dtype)
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    return image, all_latents, all_log_probs
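The SD3 variant returns `(image, all_latents, all_log_probs)`. A sketch of how a GRPO-style trainer could reduce the per-step log probs to one trajectory log-likelihood per sample; the model id and the stack-and-sum reduction are assumptions not fixed by this file:

# Sketch: one scalar log-likelihood per sampled trajectory.
import torch
from diffusers import StableDiffusion3Pipeline
from adv_grpo.diffusers_patch.sd3_pipeline_with_logprob import pipeline_with_logprob

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

images, all_latents, all_log_probs = pipeline_with_logprob(
    pipe,
    prompt=["A dog holding a sign that says '789'"],
    num_inference_steps=10,
    guidance_scale=7.0,
    noise_level=0.7,
)
# One per-sample log-prob tensor per denoising step; summing over steps gives the
# log-probability of the whole latent trajectory under the SDE sampler.
traj_logprob = torch.stack(all_log_probs, dim=1).sum(dim=1)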
adv_grpo/diffusers_patch/sd3_pipeline_with_logprob_fast.py
ADDED
|
@@ -0,0 +1,1081 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
# with the following modifications:
# - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`.
# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
from typing import Any, Dict, List, Optional, Union
import torch
import random
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from .sd3_sde_with_logprob import sde_step_with_logprob_new as sde_step_with_logprob
from PIL import Image
from torchvision import transforms


@torch.no_grad()
def pipeline_with_logprob(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    mini_num_image_per_prompt: int = 1,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
    skip_layer_guidance_scale: float = 2.8,
    noise_level: float = 0.7,
    train_num_steps: int = 1,
    process_index: int = 0,
    sample_num_steps: int = 10,
    random_timestep: Optional[int] = None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._skip_layer_guidance_scale = skip_layer_guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    # 3. Encode prompts
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables (kept in fp32 throughout this variant)
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    ).float()

    # 5. Prepare timesteps
    scheduler_kwargs = {}
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        **scheduler_kwargs,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # Draw the branch timestep deterministically per process from the first half of the schedule.
    random.seed(process_index)
    if random_timestep is None:
        random_timestep = random.randint(0, sample_num_steps // 2)

    # 6. Buffers for the trajectory window used in training
    all_latents = []
    all_log_probs = []
    all_timesteps = []

    if self.do_classifier_free_guidance:
        tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    else:
        tem_prompt_embeds = prompt_embeds
        tem_pooled_prompt_embeds = pooled_prompt_embeds
    # 7. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if i < random_timestep:
                cur_noise_level = 0
            elif i == random_timestep:
                cur_noise_level = noise_level
                # Repeat the latents mini_num_image_per_prompt times so each prompt
                # branches into a group of trajectories from this step onward.
                latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
                prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
                pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
                negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
                if self.do_classifier_free_guidance:
                    tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                    tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
                else:
                    tem_prompt_embeds = prompt_embeds
                    tem_pooled_prompt_embeds = pooled_prompt_embeds
                all_latents.append(latents)
            elif i > random_timestep and i < random_timestep + train_num_steps:
                cur_noise_level = noise_level
            else:
                cur_noise_level = 0
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=tem_prompt_embeds,
                pooled_projections=tem_pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                self.scheduler,
                noise_pred.float(),
                t.unsqueeze(0),
                latents.float(),
                noise_level=cur_noise_level,
            )

            # Record the window of steps that receive SDE noise for training.
            if i >= random_timestep and i < random_timestep + train_num_steps:
                all_latents.append(latents)
                all_log_probs.append(log_prob)
                all_timesteps.append(t.repeat(len(latents)))

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    latents = latents.to(dtype=self.vae.dtype)
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()
    return image, all_latents, all_log_probs, all_timesteps
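Note: these samplers are free functions that take `self`, so they are meant to be bound onto an existing `StableDiffusion3Pipeline` instance. The sketch below is illustrative and not part of this commit; the model id, dtype, and argument values are assumptions.

import types
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
sample = types.MethodType(pipeline_with_logprob, pipe)
images, all_latents, all_log_probs, all_timesteps = sample(
    prompt=["a photo of a corgi"],
    num_inference_steps=10,
    sample_num_steps=10,           # window from which the branch step is drawn
    mini_num_image_per_prompt=4,   # trajectories per prompt after the branch
    noise_level=0.7,
)
# all_latents holds x_t at the branch step plus one entry per recorded SDE step;
# all_log_probs and all_timesteps each have train_num_steps entries.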

@torch.no_grad()
def pipeline_with_logprob_new(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    mini_num_image_per_prompt: int = 1,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
    skip_layer_guidance_scale: float = 2.8,
    noise_level: float = 0.7,
    train_num_steps: int = 1,
    process_index: int = 0,
    sample_num_steps: int = 10,
    random_timestep: Optional[int] = None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._skip_layer_guidance_scale = skip_layer_guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    # 3. Encode prompts
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables (kept in the model dtype in this variant)
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    scheduler_kwargs = {}
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        **scheduler_kwargs,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # Draw the branch timestep deterministically per process from the first half of the schedule.
    random.seed(process_index)
    if random_timestep is None:
        random_timestep = random.randint(0, sample_num_steps // 2)

    # 6. Buffers for the trajectory window used in training
    all_latents = []
    all_log_probs = []
    all_timesteps = []

    if self.do_classifier_free_guidance:
        tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    else:
        tem_prompt_embeds = prompt_embeds
        tem_pooled_prompt_embeds = pooled_prompt_embeds
    # 7. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if i < random_timestep:
                cur_noise_level = 0
            elif i == random_timestep:
                cur_noise_level = noise_level
                # Repeat the latents mini_num_image_per_prompt times so each prompt
                # branches into a group of trajectories from this step onward.
                latents = latents.repeat(mini_num_image_per_prompt, 1, 1, 1)
                prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
                pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
                negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
                if self.do_classifier_free_guidance:
                    tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                    tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
                else:
                    tem_prompt_embeds = prompt_embeds
                    tem_pooled_prompt_embeds = pooled_prompt_embeds
                all_latents.append(latents)
            elif i > random_timestep and i < random_timestep + train_num_steps:
                cur_noise_level = noise_level
            else:
                cur_noise_level = 0
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=tem_prompt_embeds,
                pooled_projections=tem_pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents_dtype = latents.dtype

            latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                self.scheduler,
                noise_pred.float(),
                t.unsqueeze(0),
                latents.float(),
                noise_level=cur_noise_level,
            )

            if latents.dtype != latents_dtype:
                latents = latents.to(latents_dtype)

            if i >= random_timestep and i < random_timestep + train_num_steps:
                all_latents.append(latents)
                all_log_probs.append(log_prob)
                all_timesteps.append(t.repeat(len(latents)))

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    latents = latents.to(dtype=self.vae.dtype)
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()
    return image, all_latents, all_log_probs, all_timesteps
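Note: `pipeline_with_logprob_new` is identical to `pipeline_with_logprob` except for dtype handling: it keeps the latents in the dtype produced by `prepare_latents` (no up-front cast to fp32) and restores that dtype after every `sde_step_with_logprob` call, which itself computes in fp32 internally.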

@torch.no_grad()
def pipeline_with_logprob_random(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    mini_num_image_per_prompt: int = 1,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
    skip_layer_guidance_scale: float = 2.8,
    noise_level: float = 0.7,
    train_num_steps: int = 1,
    process_index: int = 0,
    sample_num_steps: int = 10,
    random_timestep: Optional[int] = None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._skip_layer_guidance_scale = skip_layer_guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    # 3. Encode prompts, then repeat them up front: unlike the variants above, every
    # group member here gets its own initial latent rather than branching mid-trajectory.
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    prompt_embeds = prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.repeat(mini_num_image_per_prompt, 1, 1)
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(mini_num_image_per_prompt, 1)

    # 4. Prepare latent variables (one independent latent per group member)
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        prompt_embeds.shape[0],
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    scheduler_kwargs = {}
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        **scheduler_kwargs,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # Draw the branch timestep deterministically per process from the first half of the schedule.
    random.seed(process_index)
    if random_timestep is None:
        random_timestep = random.randint(0, sample_num_steps // 2)

    # 6. Buffers for the trajectory window used in training
    all_latents = []
    all_log_probs = []
    all_timesteps = []
    if self.do_classifier_free_guidance:
        tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    else:
        tem_prompt_embeds = prompt_embeds
        tem_pooled_prompt_embeds = pooled_prompt_embeds
    # 7. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if i < random_timestep:
                cur_noise_level = 0
            elif i == random_timestep:
                cur_noise_level = noise_level
                # Embeddings and latents were already repeated mini_num_image_per_prompt
                # times before the loop, so only record the branch-point latents here.
                all_latents.append(latents)
            elif i > random_timestep and i < random_timestep + train_num_steps:
                cur_noise_level = noise_level
            else:
                cur_noise_level = 0
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=tem_prompt_embeds,
                pooled_projections=tem_pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents_dtype = latents.dtype

            latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                self.scheduler,
                noise_pred.float(),
                t.unsqueeze(0),
                latents.float(),
                noise_level=cur_noise_level,
            )

            if latents.dtype != latents_dtype:
                latents = latents.to(latents_dtype)

            if i >= random_timestep and i < random_timestep + train_num_steps:
                all_latents.append(latents)
                all_log_probs.append(log_prob)
                all_timesteps.append(t.repeat(len(latents)))

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    latents = latents.to(dtype=self.vae.dtype)
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()
    return image, all_latents, all_log_probs, all_timesteps
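Note: the grouping created by `mini_num_image_per_prompt` is what GRPO-style objectives normalize over. A hedged sketch of how grouped rewards are typically turned into advantages (this helper is not part of the repo; names are illustrative, and the layout assumes the `.repeat(k, ...)` tiling used above, where group members for a prompt sit at stride `num_prompts`):

import torch

def group_advantages(rewards: torch.Tensor, num_prompts: int, eps: float = 1e-4) -> torch.Tensor:
    """rewards: (k * num_prompts,) laid out by `.repeat(k, ...)`; returns per-sample
    advantages normalized within each prompt's group of k trajectories."""
    r = rewards.view(-1, num_prompts)                      # (k, num_prompts)
    adv = (r - r.mean(dim=0, keepdim=True)) / (r.std(dim=0, keepdim=True) + eps)
    return adv.flatten()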

def move_scheduler_to_device(scheduler, device="cuda"):
    # Move every tensor attribute of the scheduler (e.g. sigmas, timesteps) to `device`.
    for attr_name in dir(scheduler):
        attr = getattr(scheduler, attr_name)
        if isinstance(attr, torch.Tensor):
            setattr(scheduler, attr_name, attr.to(device))
    return scheduler
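Note: this keeps the scheduler's buffers on the same device as the latents, which the SDE step relies on. A typical call (illustrative): `pipe.scheduler = move_scheduler_to_device(pipe.scheduler, "cuda")`.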

def image_to_latent(pipe, images: Union[Image.Image, List[Image.Image]], device="cuda"):
    # Normalize to a list
    if isinstance(images, Image.Image):
        images = [images]

    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),              # to [0, 1]
        transforms.Normalize([0.5], [0.5])  # map to [-1, 1]
    ])

    # Batch the images
    img_tensors = [preprocess(img) for img in images]  # list of [3, 512, 512]
    img_tensor = torch.stack(img_tensors, dim=0).to(device, dtype=torch.float32)  # [B, 3, 512, 512]

    # Encode with the VAE; subtract shift_factor to mirror the decode path used above,
    # which computes latents / scaling_factor + shift_factor.
    latent = pipe.vae.encode(img_tensor).latent_dist.sample()
    latent = (latent - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    return latent.to(torch.bfloat16)  # [B, 4, 64, 64], assuming a 512px input downsampled 8x
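A quick round-trip sketch (illustrative, not in the repo; assumes `pipe` and a PIL `img` exist): encoding then inverting the normalization and decoding should approximately reproduce the 512x512-resized input.

lat = image_to_latent(pipe, img, device="cuda")
lat = (lat / pipe.vae.config.scaling_factor + pipe.vae.config.shift_factor).to(pipe.vae.dtype)
recon = pipe.image_processor.postprocess(
    pipe.vae.decode(lat, return_dict=False)[0], output_type="pil"
)[0]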

def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
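Note: for a 4-D latent batch, `get_sigmas` returns shape (B, 1, 1, 1), so the forward-noising expression `(1 - sigmas) * x0 + sigmas * noise` broadcasts cleanly over latents of shape (B, C, H, W). Illustrative shape check (assumes `set_timesteps` was already called on the scheduler):

sigma = get_sigmas(pipe.scheduler, pipe.scheduler.timesteps[:1], "cuda", n_dim=4)
assert sigma.shape == (1, 1, 1, 1)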

@torch.no_grad()
def flux_to_sd3_denoise(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    flux_images=None,
    device="cuda",
    output_type: Optional[str] = "pil",
    num_inference_steps: int = 20,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 256,
    noise_level: float = 0.7,
    random_timestep: Optional[int] = None,
    noise_timestep_ratio: float = 0.4,
    clip_skip: Optional[int] = None,
):
    """
    Take images generated by Flux -> encode to latents -> add noise -> denoise with SD3 for several steps.
    Outputs match pipeline_with_logprob: image, all_latents, all_log_probs, all_timesteps.
    """
    # 1. Convert to latents
    flux_latent = image_to_latent(self, flux_images, device)
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip

    # 2. Prepare the scheduler
    noise_scheduler = self.scheduler
    noise_scheduler.set_timesteps(num_inference_steps)
    timesteps = noise_scheduler.timesteps.to(device)

    # 3. Re-noise the latents. The ratio must be converted into a schedule index:
    # a bare float cannot be used to index the timesteps tensor.
    target_idx = torch.tensor([int(noise_timestep_ratio * (len(timesteps) - 1))], device=device)
    t = timesteps[target_idx].to(device)

    noise = torch.randn_like(flux_latent)
    sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
    latents = (1.0 - sigmas) * flux_latent + sigmas * noise
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 4. Encode prompts (aligned with pipeline_with_logprob)
    lora_scale = None
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1)
    negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1)
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1)

    if self.do_classifier_free_guidance:
        tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    else:
        tem_prompt_embeds = prompt_embeds
        tem_pooled_prompt_embeds = pooled_prompt_embeds

    # 5. Denoise starting from the current t
    noise_scheduler.set_timesteps(num_inference_steps)
    timesteps = noise_scheduler.timesteps.to(device)
    start_idx = (timesteps >= t[0]).nonzero()[-1].item()
    timesteps = timesteps[start_idx:]

    all_latents, all_log_probs, all_timesteps = [], [], []
    noise_scheduler = move_scheduler_to_device(noise_scheduler, device)

    for index, t_cur in enumerate(timesteps):
        if index == 0:
            all_latents.append(latents)

        # Inject SDE noise only for the first two (trained) steps; the rest are deterministic.
        if index < 2:
            cur_noise_level = noise_level
        else:
            cur_noise_level = 0.0

        latent_model_input = (
            torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        )
        t_input = t_cur.expand(latent_model_input.shape[0]).to(device)

        latents_dtype = latents.dtype
        model_pred = self.transformer(
            hidden_states=latent_model_input,
            timestep=t_input,
            encoder_hidden_states=tem_prompt_embeds,
            pooled_projections=tem_pooled_prompt_embeds,
            return_dict=False,
        )[0]

        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
            model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
            noise_scheduler,
            model_pred.float(),
            t_cur.repeat(len(latents)),
            latents.float(),
            noise_level=cur_noise_level,  # was `noise_level`, which left cur_noise_level unused
        )
        if latents.dtype != latents_dtype:
            latents = latents.to(latents_dtype)

        # Record only the two noised steps for training.
        if index >= 0 and index < 2:
            all_latents.append(latents)
            all_log_probs.append(log_prob)
            all_timesteps.append(t_cur.repeat(len(latents)))

    # 6. Final decode
    denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    denoised_latent = denoised_latent.to(dtype=self.vae.dtype)

    image = self.vae.decode(denoised_latent, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    return image, all_latents, all_log_probs, all_timesteps
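Worked example of the re-noising index: with `num_inference_steps=20` and `noise_timestep_ratio=0.4`, `target_idx = int(0.4 * 19) = 7`, so the Flux latent is pushed to the sigma of the 8th schedule entry and SD3 then denoises the remaining 13 steps, collecting latents and log-probs only for the first two of them.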

@torch.no_grad()
def flux_to_sd3_denoise_random(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    flux_images=None,
    device="cuda",
    output_type: Optional[str] = "pil",
    num_inference_steps: int = 20,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 256,
    noise_level: float = 0.7,
    random_timestep: Optional[int] = None,
    noise_timestep_ratio: float = 0.4,
    clip_skip: Optional[int] = None,
):
    """
    Take images generated by Flux -> encode to latents -> add noise -> denoise with SD3 for several steps.
    Outputs match pipeline_with_logprob: image, all_latents, all_log_probs, all_timesteps.
    """
    # 1. Convert to latents
    flux_latent = image_to_latent(self, flux_images, device)
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip

    # 2. Prepare the scheduler
    noise_scheduler = self.scheduler
    noise_scheduler.set_timesteps(num_inference_steps)
    timesteps = noise_scheduler.timesteps.to(device)

    # 3. Re-noise the latents, drawing the start step uniformly from schedule
    # indices 5..10 instead of using a fixed noise_timestep_ratio.
    target_idx = torch.tensor([random.randint(5, 10)], device=device)
    t = timesteps[target_idx].to(device)
    # Sample standard Gaussian noise
    noise = torch.randn_like(flux_latent)
    # Look up the corresponding sigma
    sigmas = get_sigmas(noise_scheduler, t, device, n_dim=flux_latent.ndim, dtype=flux_latent.dtype)
    # Add noise to the latents
    latents = (1.0 - sigmas) * flux_latent + sigmas * noise

    # 4. Encode prompts (aligned with pipeline_with_logprob)
    lora_scale = None
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    prompt_embeds = prompt_embeds.repeat(latents.shape[0], 1, 1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(latents.shape[0], 1)
    negative_prompt_embeds = negative_prompt_embeds.repeat(latents.shape[0], 1, 1)
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(latents.shape[0], 1)

    if self.do_classifier_free_guidance:
        tem_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        tem_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    else:
        tem_prompt_embeds = prompt_embeds
        tem_pooled_prompt_embeds = pooled_prompt_embeds

    # 5. Denoise starting from the current t
    noise_scheduler.set_timesteps(num_inference_steps)
    timesteps = noise_scheduler.timesteps.to(device)
    start_idx = (timesteps >= t[0]).nonzero()[-1].item()
    timesteps = timesteps[start_idx:]

    all_latents, all_log_probs, all_timesteps = [], [], []
    noise_scheduler = move_scheduler_to_device(noise_scheduler, device)

    for index, t_cur in enumerate(timesteps):
        if index == 0:
            all_latents.append(latents)

        latent_model_input = (
            torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        )
        t_input = t_cur.expand(latent_model_input.shape[0]).to(device)

        latents_dtype = latents.dtype
        model_pred = self.transformer(
            hidden_states=latent_model_input,
            timestep=t_input,
            encoder_hidden_states=tem_prompt_embeds,
            pooled_projections=tem_pooled_prompt_embeds,
            return_dict=False,
        )[0]

        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
            model_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
            noise_scheduler,
            model_pred.float(),
            t_cur.repeat(len(latents)),
            latents.float(),
            noise_level=noise_level,
        )
        if latents.dtype != latents_dtype:
            latents = latents.to(latents_dtype)

        # Record only the first two steps after the branch for training.
        if index < 2:
            all_latents.append(latents)
            all_log_probs.append(log_prob)
            all_timesteps.append(t_cur.repeat(len(latents)))

    # 6. Final decode
    denoised_latent = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    denoised_latent = denoised_latent.to(dtype=self.vae.dtype)

    image = self.vae.decode(denoised_latent, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    return image, all_latents, all_log_probs, all_timesteps
adv_grpo/diffusers_patch/sd3_sde_with_logprob.py
ADDED
@@ -0,0 +1,139 @@
# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
# We adapt it from DDIM to flow matching.

import math
from typing import Optional, Union
import torch

from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler


def sde_step_with_logprob(
    self: FlowMatchEulerDiscreteScheduler,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
    process from the learned model outputs (most often the predicted velocity).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from the learned flow model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
    """
    # bf16 can overflow when computing prev_sample_mean, so convert everything to fp32
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    step_index = [self.index_for_timestep(t) for t in timestep]
    prev_step_index = [step + 1 for step in step_index]
    sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
    sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
    sigma_max = self.sigmas[1].item()
    dt = sigma_prev - sigma

    # guard against division by zero at sigma == 1 by substituting the next sigma
    std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * noise_level

    # our SDE
    prev_sample_mean = sample * (1 + std_dev_t**2 / (2 * sigma) * dt) + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise

    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
        - torch.log(std_dev_t * torch.sqrt(-1 * dt))
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, log_prob, prev_sample_mean, std_dev_t
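The log-density above is the standard diagonal Gaussian with mean `prev_sample_mean` and std `std_dev_t * sqrt(-dt)`. A quick equivalence check against `torch.distributions` (shapes are illustrative):

import math
import torch
from torch.distributions import Normal

mu = torch.randn(2, 16, 8, 8)
std = torch.full((2, 1, 1, 1), 0.3)
x = mu + std * torch.randn_like(mu)
manual = -((x - mu) ** 2) / (2 * std**2) - torch.log(std) - math.log(math.sqrt(2 * math.pi))
assert torch.allclose(manual, Normal(mu, std).log_prob(x), atol=1e-5)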


def sde_step_with_logprob_new(
    self: FlowMatchEulerDiscreteScheduler,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
    process from the learned model outputs (most often the predicted velocity).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from the learned flow model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
    """
    # bf16 can overflow when computing prev_sample_mean, so convert everything to fp32
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    step_index = [self.index_for_timestep(t) for t in timestep]
    prev_step_index = [step + 1 for step in step_index]
    sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
    sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
    sigma_max = self.sigmas[1].item()
    dt = sigma_prev - sigma

    # Flow-SDE (kept for reference):
    # std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * noise_level * torch.sqrt(-1 * dt)
    # prev_sample_mean = sample * (1 + std_dev_t**2 / (2 * sigma) * dt) + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt

    # Flow-CPS
    std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)  # sigma_t in the paper
    pred_original_sample = sample - sigma * model_output          # predicted x_0 in the paper
    noise_estimate = sample + model_output * (1 - sigma)          # predicted x_1 in the paper
    prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2)

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # remove all constants
    log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)

    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, log_prob, prev_sample_mean, std_dev_t
|
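Read off the code, the Flow-CPS branch above implements the following update (my own transcription from the inline comments, with eta = noise_level and v the predicted velocity; a sketch, not a quote from the paper):

% \hat{x}_0 = x_t - \sigma\, v,\qquad \hat{x}_1 = x_t + (1 - \sigma)\, v
% s = \sigma_{\mathrm{prev}} \sin(\eta \pi / 2)
% x_{\mathrm{prev}} \sim \mathcal{N}\!\big( (1 - \sigma_{\mathrm{prev}})\,\hat{x}_0 + \sqrt{\sigma_{\mathrm{prev}}^2 - s^2}\,\hat{x}_1,\; s^2 I \big)

Since s depends only on the shared noise schedule and not on the model, the samples at a given step differ in log-density only through the squared distance to the mean, which is why the code keeps just -(x_prev - mean)^2 after dropping the per-step scale and normalization ("remove all constants").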
adv_grpo/diffusers_patch/train_dreambooth_lora_flux.py
ADDED
@@ -0,0 +1,144 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    if hasattr(text_encoder, "module"):
        dtype = text_encoder.module.dtype
    else:
        dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    if hasattr(text_encoder, "module"):
        dtype = text_encoder.module.dtype
    else:
        dtype = text_encoder.dtype
    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if hasattr(text_encoders[0], "module"):
        dtype = text_encoders[0].module.dtype
    else:
        dtype = text_encoders[0].dtype

    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )

    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )

    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
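A minimal usage sketch for the helpers above, assuming a Flux-style pipeline with a CLIP and a T5 encoder; the checkpoint id and the exact output shapes are assumptions about FLUX.1, not part of this file:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[pipe.text_encoder, pipe.text_encoder_2],  # CLIP first, T5 second
    tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
    prompt="a photo of an astronaut riding a horse",
    max_sequence_length=512,
    device="cuda",
)
# prompt_embeds: the T5 sequence, e.g. (1, 512, 4096); pooled_prompt_embeds: the CLIP pooled vector, e.g. (1, 768).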
adv_grpo/diffusers_patch/train_dreambooth_lora_sd3.py
ADDED
@@ -0,0 +1,144 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
        device=device if device is not None else text_encoders[-1].device,
    )

    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds
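The pad-and-concatenate step at the end is what builds SD3's joint text embedding: the two CLIP streams are concatenated channel-wise, zero-padded to the T5 width, then stacked along the sequence axis. A tiny shape walkthrough (the widths below are SD3's usual encoder sizes and are an assumption about the checkpoint):

import torch

clip_l = torch.randn(1, 77, 768)    # CLIP-L hidden states
clip_g = torch.randn(1, 77, 1280)   # CLIP-bigG hidden states
t5 = torch.randn(1, 256, 4096)      # T5-XXL hidden states

clip = torch.cat([clip_l, clip_g], dim=-1)              # (1, 77, 2048)
clip = torch.nn.functional.pad(clip, (0, 4096 - 2048))  # (1, 77, 4096), zero-padded channels
joint = torch.cat([clip, t5], dim=-2)                   # (1, 77 + 256, 4096)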
adv_grpo/diffusers_patch/wan_pipeline_with_logprob.py
ADDED
@@ -0,0 +1,373 @@
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import torch
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput  # used in the return at the bottom
from diffusers.utils.torch_utils import randn_tensor
import math
import numpy as np


def sde_step_with_logprob(
    self: UniPCMultistepScheduler,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
    determistic: bool = False,
    return_pixel_log_prob: bool = False,
    return_dt_and_std_dev_t: bool = False,
):
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
    process from the learned model outputs (most often the predicted velocity).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from the learned flow model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
    """
    # bf16 can overflow when computing prev_sample_mean, so convert all variables to fp32.
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    step_index = [self.index_for_timestep(t) for t in timestep]
    prev_step_index = [step + 1 for step in step_index]

    self.sigmas = self.sigmas.to(sample.device)
    sigma = self.sigmas[step_index].view(-1, 1, 1, 1, 1)
    sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1, 1)
    sigma_max = self.sigmas[1].item()
    sigma_min = self.sigmas[-1].item()
    dt = sigma_prev - sigma

    std_dev_t = sigma_min + (sigma_max - sigma_min) * sigma
    prev_sample_mean = (
        sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
        + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
    )

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise

    # No noise is added during evaluation.
    if determistic:
        prev_sample = sample + dt * model_output

    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
        - torch.log(std_dev_t * torch.sqrt(-1 * dt))
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    if return_dt_and_std_dev_t:
        return prev_sample, log_prob, prev_sample_mean, std_dev_t, torch.sqrt(-1 * dt)
    return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1 * dt)

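For reference, the log-prob above is just the elementwise log-density of the diagonal-Gaussian Euler-Maruyama transition; in LaTeX (my own transcription of the code, not taken from the patch):

% x_{t'} \sim \mathcal{N}\!\left(\mu_t,\; \hat\sigma_t^2 I\right), \qquad \hat\sigma_t = \texttt{std\_dev\_t}\cdot\sqrt{-\,dt}
% \log p(x_{t'} \mid x_t) = -\frac{(x_{t'} - \mu_t)^2}{2\hat\sigma_t^2} - \log\hat\sigma_t - \tfrac{1}{2}\log 2\pi
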
def wan_pipeline_with_logprob(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 480,
    width: int = 832,
    num_frames: int = 81,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    determistic: bool = False,
    kl_reward: float = 0.0,
    return_pixel_log_prob: bool = False,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass
            `prompt_embeds` instead.
        height (`int`, defaults to `480`):
            The height in pixels of the generated video.
        width (`int`, defaults to `832`):
            The width in pixels of the generated video.
        num_frames (`int`, defaults to `81`):
            The number of frames in the generated video.
        num_inference_steps (`int`, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to higher quality at the
            expense of slower inference.
        guidance_scale (`float`, defaults to `5.0`):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2 of the [Imagen
            paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`.
            Higher guidance scale encourages outputs that are closely linked to the text `prompt`,
            usually at the expense of lower quality.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video. Choose between `"np"`, `"pil"`, or `"latent"`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
            the `._callback_tensor_inputs` attribute of your pipeline class.

    Returns:
        If `return_dict` is `True`, a [`WanPipelineOutput`] followed by the per-step latents, log-probs, and KL
        terms; otherwise the tuple `(video, all_latents, all_log_probs, all_kl)`.
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
    )

    if num_frames % self.vae_scale_factor_temporal != 1:
        print(
            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
        )
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
    num_frames = max(num_frames, 1)

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    all_latents = [latents]
    all_log_probs = []
    all_kl = []

    # 6. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            latents_ori = latents.clone()
            self._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)
            timestep = t.expand(latents.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
            noise_pred = noise_pred.to(prompt_embeds.dtype)

            if self.do_classifier_free_guidance:
                noise_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

            latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(
                self.scheduler,
                noise_pred.float(),
                t.unsqueeze(0),
                latents.float(),
                determistic=determistic,
                return_pixel_log_prob=return_pixel_log_prob,
            )
            prev_latents = latents.clone()

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # KL reward requested, and this is the sampling (not evaluation) path:
            # rerun the step with the adapter disabled to score the frozen reference policy.
            if kl_reward > 0 and not determistic:
                latent_model_input = torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori
                with self.transformer.disable_adapter():
                    noise_pred = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                noise_pred = noise_pred.to(prompt_embeds.dtype)
                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                _, ref_log_prob, ref_prev_latents_mean, ref_std_dev_t = sde_step_with_logprob(
                    self.scheduler,
                    noise_pred.float(),
                    t.unsqueeze(0),
                    latents_ori.float(),
                    prev_sample=prev_latents.float(),
                    determistic=determistic,
                )
                assert torch.allclose(std_dev_t, ref_std_dev_t)
                kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (2 * std_dev_t**2)
                kl = kl.mean(dim=tuple(range(1, kl.ndim)))
                all_kl.append(kl)
            else:
                # No KL reward requested: skip the reference pass and store zeros as placeholders.
                all_kl.append(torch.zeros(len(latents), device=latents.device))

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    self._current_timestep = None

    if output_type != "latent":
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # Note: this holds the reciprocal std, so dividing below multiplies by the true std.
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    self.maybe_free_model_hooks()

    if not return_dict:
        return (video, all_latents, all_log_probs, all_kl)

    return WanPipelineOutput(frames=video), all_latents, all_log_probs, all_kl
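A minimal usage sketch, assuming a Wan T2V checkpoint and the UniPC scheduler setup from the diffusers Wan examples; the model id and flow_shift value are assumptions, not prescribed by this patch:

import torch
from diffusers import UniPCMultistepScheduler, WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=3.0)

out, all_latents, all_log_probs, all_kl = wan_pipeline_with_logprob(
    pipe,
    prompt="a cat surfing a wave at sunset",
    num_frames=33,                # (num_frames - 1) must be divisible by the temporal VAE factor
    num_inference_steps=20,
    kl_reward=0.0,                # > 0 also runs the frozen reference model for a per-step KL term
)
# out.frames holds the decoded video; len(all_log_probs) == num_inference_steps,
# each entry with one log-prob per video in the batch.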
adv_grpo/diffusers_patch/wan_prompt_embedding.py
ADDED
@@ -0,0 +1,97 @@
import torch
from typing import Any, Callable, Dict, List, Optional, Union


def _get_t5_prompt_embeds(
    text_encoder,
    tokenizer,
    prompt: Union[str, List[str]] = None,
    max_sequence_length: int = 226,
    num_videos_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    seq_lens = mask.gt(0).sum(dim=1).long()

    prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    # Trim each sequence to its true token length, then zero-pad back to max_sequence_length,
    # so padding positions carry exact zeros rather than encoder outputs.
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
    )

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds


def encode_prompt(
    text_encoder,
    tokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 226,
    num_videos_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        text_encoder: A singleton list holding the T5 text encoder; only the first entry is used.
        tokenizer: The matching tokenizer, with the same list convention.
        prompt (`str` or `List[str]`):
            Prompt to be encoded.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Maximum token length; longer prompts are truncated.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt.
        device (`torch.device`, *optional*):
            Overridden by the encoder's own device below.
        dtype (`torch.dtype`, *optional*):
            Overridden by the encoder's own dtype below.
    """
    device = text_encoder[0].device
    dtype = text_encoder[0].dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt

    prompt_embeds = _get_t5_prompt_embeds(
        text_encoder=text_encoder[0],
        tokenizer=tokenizer[0],
        prompt=prompt,
        max_sequence_length=max_sequence_length,
        num_videos_per_prompt=num_videos_per_prompt,
        device=device,
        dtype=dtype,
    )

    return prompt_embeds
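The trim-then-zero-pad step is the subtle part of the file above; a tiny standalone demo of what it does, using fabricated toy tensors:

import torch

max_len = 6
embeds = torch.arange(2 * max_len * 3, dtype=torch.float32).view(2, max_len, 3)
seq_lens = torch.tensor([4, 2])  # true (unpadded) token counts per prompt

trimmed = [u[:v] for u, v in zip(embeds, seq_lens)]
padded = torch.stack(
    [torch.cat([u, u.new_zeros(max_len - u.size(0), u.size(1))]) for u in trimmed], dim=0
)
assert padded.shape == (2, max_len, 3)
assert padded[0, 4:].abs().sum() == 0  # positions past the true length are exact zeros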
adv_grpo/ema.py
ADDED
@@ -0,0 +1,88 @@
# Copied from another repo, but I can't remember exactly which one.

from collections.abc import Iterable

import torch


class EMAModuleWrapper:
    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        update_step_interval: int = 1,
        device: torch.device | None = None,
    ):
        parameters = list(parameters)
        self.ema_parameters = [p.clone().detach().to(device) for p in parameters]

        self.temp_stored_parameters = None

        self.decay = decay
        self.update_step_interval = update_step_interval
        self.device = device

        # TODO: add an automatic decay calculation based on this formula:
        # The impact of the last n steps can be calculated as:
        #   impact = 1 - (decay^n)
        # The number of steps needed to reach a specific impact is:
        #   n = log_decay(1 - impact)
        # The decay needed to reach a specific impact after n steps is:
        #   decay = (1 - impact)^(1/n)

    def get_current_decay(self, optimization_step) -> float:
        # Warm up the decay so that early steps track the raw weights closely.
        return min(
            (1 + optimization_step) / (10 + optimization_step),
            self.decay,
        )

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter], optimization_step):
        parameters = list(parameters)

        one_minus_decay = 1 - self.get_current_decay(optimization_step)

        if (optimization_step + 1) % self.update_step_interval == 0:
            for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
                if parameter.requires_grad:
                    if ema_parameter.device == parameter.device:
                        ema_parameter.add_(one_minus_decay * (parameter - ema_parameter))
                    else:
                        # in-place calculations to save memory
                        parameter_copy = parameter.detach().to(ema_parameter.device)
                        parameter_copy.sub_(ema_parameter)
                        parameter_copy.mul_(one_minus_decay)
                        ema_parameter.add_(parameter_copy)
                        del parameter_copy

    def to(self, device: torch.device = None, dtype: torch.dtype = None) -> None:
        self.device = device
        self.ema_parameters = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.ema_parameters
        ]

    def copy_ema_to(self, parameters: Iterable[torch.nn.Parameter], store_temp: bool = True) -> None:
        # Materialize first: `parameters` may be a one-shot generator and is consumed
        # twice below (once for the temp copy, once for the overwrite).
        parameters = list(parameters)
        if store_temp:
            self.temp_stored_parameters = [parameter.detach().cpu() for parameter in parameters]

        for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True):
            parameter.data.copy_(ema_parameter.to(parameter.device).data)

    def copy_temp_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        for temp_parameter, parameter in zip(self.temp_stored_parameters, parameters, strict=True):
            parameter.data.copy_(temp_parameter.data)

        self.temp_stored_parameters = None

    def load_state_dict(self, state_dict: dict) -> None:
        self.decay = self.decay if self.decay else state_dict.get("decay", self.decay)
        self.ema_parameters = state_dict.get("ema_parameters")
        self.to(self.device)

    def state_dict(self) -> dict:
        return {
            "decay": self.decay,
            "ema_parameters": self.ema_parameters,
        }
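A minimal sketch of the intended usage; the model, optimizer, and loss below are placeholders:

import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
ema = EMAModuleWrapper(model.parameters(), decay=0.999, device=torch.device("cpu"))

for step in range(100):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.step(model.parameters(), step)  # updates the shadow weights with the warmed-up decay

# Evaluate with the EMA weights, then restore the live ones.
ema.copy_ema_to(model.parameters(), store_temp=True)
# ... run evaluation ...
ema.copy_temp_to(model.parameters())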
adv_grpo/imagereward_scorer.py
ADDED
@@ -0,0 +1,40 @@
from PIL import Image
import torch
import ImageReward as RM


class ImageRewardScorer(torch.nn.Module):
    def __init__(self, device="cuda", dtype=torch.float32):
        super().__init__()
        self.model_path = "ImageReward-v1.0"
        self.device = device
        self.dtype = dtype
        self.model = RM.load(self.model_path, device=device).eval().to(dtype=dtype)
        self.model.requires_grad_(False)

    @torch.no_grad()
    def __call__(self, prompts, images):
        rewards = []
        for prompt, image in zip(prompts, images):
            _, reward = self.model.inference_rank(prompt, [image])
            rewards.append(reward)
        return rewards


# Usage example
def main():
    scorer = ImageRewardScorer(
        device="cuda",
        dtype=torch.float32,
    )

    images = [
        "astronaut.jpg",
    ]
    pil_images = [Image.open(img) for img in images]
    prompts = [
        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    ]
    print(scorer(prompts, pil_images))


if __name__ == "__main__":
    main()
adv_grpo/inflated_layers.py
ADDED
@@ -0,0 +1,305 @@
from functools import partial
from typing import Literal

from einops import rearrange
from torch import Tensor
from torch.nn import ConvTranspose2d, ConvTranspose3d

from flow_grpo.inflated_lib import (
    MemoryState,
    extend_head,
    inflate_bias,
    inflate_distribution_bias,
    inflate_distribution_weight,
    inflate_weight,
    modify_state_dict,
)
from flow_grpo.conv_gradfix import GradFixConv2d, GradFixConv3d

VERBOSE = False

_inflation_mode_t = Literal["none", "flatten", "partial_flatten", "pad", "tile"]
_direction_t = Literal["", "out", "in"]


class InflatedCausalConv3d(GradFixConv3d):
    """
    Note:
        To align the behavior of pretrained 2D models,
        if you compose a video clip from a single image by:
        - duplicating: set shape_norm = True
        - padding zeros: set shape_norm = False
        to avoid gaps at the beginning of the training process.
    """

    def __init__(
        self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs
    ):
        self.shape_norm = shape_norm
        self.inflation_mode = inflation_mode
        self.padding_bank = None
        super().__init__(*args, **kwargs)
        self.temporal_padding = self.padding[0]
        self.padding = (0, *self.padding[1:])  # Remove temporal pad to keep causal.

    def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
        bank_size = self.stride[0] - self.kernel_size[0]
        padding_bank = (
            input[:, :, bank_size:].detach()
            if (bank_size != 0 and memory_state != MemoryState.DISABLED)
            else None
        )
        if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE):
            input = extend_head(input, memory=self.padding_bank)
        else:
            input = extend_head(input, times=self.temporal_padding * 2)
        if memory_state != MemoryState.DISABLED and not self.training:
            self.padding_bank = padding_bank
        return super().forward(input)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        if self.inflation_mode == "none":
            super()._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
        else:
            # NOTE: need to switch off strict
            super()._load_from_state_dict(
                modify_state_dict(
                    self,
                    state_dict,
                    prefix,
                    verbose=VERBOSE,
                    inflate_weight_fn=partial(inflate_weight, position="tail"),
                    inflate_bias_fn=partial(inflate_bias, position="tail"),
                ),
                prefix,
                local_metadata,
                False,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )


class InflatedDistributionCausalConv3d(GradFixConv3d):
    """
    Note:
        Direction:
        - out: this layer generates the mean/std of some distribution;
        - in: this layer takes tensors sampled from the output of an `out` layer as input.
    """

    def __init__(
        self,
        *args,
        direction: _direction_t,
        inflation_mode: _inflation_mode_t,
        shape_norm: bool = True,
        **kwargs,
    ):
        self.shape_norm = shape_norm
        self.inflation_mode = inflation_mode
        self.direction = direction
        self.padding_bank = None
        super().__init__(*args, **kwargs)
        self.temporal_padding = self.padding[0]
        self.padding = (0, *self.padding[1:])  # Remove temporal pad to keep causal.

    def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
        bank_size = self.stride[0] - self.kernel_size[0]
        padding_bank = (
            input[:, :, bank_size:].detach()
            if (bank_size != 0 and memory_state != MemoryState.DISABLED)
            else None
        )
        if (self.padding_bank is not None) and (memory_state == MemoryState.ACTIVE):
            input = extend_head(input, memory=self.padding_bank)
        else:
            input = extend_head(input, times=self.temporal_padding * 2)
        if memory_state != MemoryState.DISABLED and not self.training:
            self.padding_bank = padding_bank
        return super().forward(input)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        if self.inflation_mode == "none":
            super()._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
        else:
            super()._load_from_state_dict(
                modify_state_dict(
                    self,
                    state_dict,
                    prefix,
                    verbose=VERBOSE,
                    inflate_weight_fn=partial(
                        inflate_distribution_weight, direction=self.direction, position="tail"
                    ),
                    inflate_bias_fn=partial(
                        inflate_distribution_bias, direction=self.direction, position="tail"
                    ),
                ),
                prefix,
                local_metadata,
                False,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )


class InflatedConvTranspose3d(ConvTranspose3d):
    # Note: it's not a causal one.
    def __init__(
        self, *args, inflation_mode: _inflation_mode_t, shape_norm: bool = True, **kwargs
    ):
        self.shape_norm = shape_norm
        self.inflation_mode = inflation_mode
        super().__init__(*args, **kwargs)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        if self.inflation_mode == "none":
            super()._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
        else:
            # NOTE: need to switch off strict
            super()._load_from_state_dict(
                modify_state_dict(
                    self,
                    state_dict,
                    prefix,
                    verbose=VERBOSE,
                    inflate_weight_fn=partial(inflate_weight, position="center"),
                    inflate_bias_fn=partial(inflate_bias, position="center"),
                ),
                prefix,
                local_metadata,
                False,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )


class FlattenedConvTranspose3d(ConvTranspose2d):
    def forward(self, input: Tensor, **kwargs) -> Tensor:
        output = rearrange(input, "b c f h w -> (b f) c h w")
        output = super().forward(output)
        output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2))
        return output


class FlattenedConv3d(GradFixConv2d):
    def forward(self, input: Tensor, **kwargs) -> Tensor:
        output = rearrange(input, "b c f h w -> (b f) c h w")
        output = super().forward(output)
        output = rearrange(output, "(b f) c h w -> b c f h w", f=input.size(2))
        return output


def init_causal_conv3d(
    *args,
    inflation_mode: _inflation_mode_t,
    direction: _direction_t = "",
    partial_switch: bool = False,
    **kwargs,
):
    """
    Initialize a Causal-3D convolution layer.
    Parameters:
        inflation_mode: Listed below. It's compatible with all the 3D-VAE checkpoints we have.
            - none: No inflation will be conducted.
              The loading logic of the state dict falls back to the default.
            - flatten: Produces a `fake` 3D layer,
              which simply squeezes the batch and depth axes together
              and then runs a 2D convolution.
            - partial_flatten:
              - layers with `partial_switch` on: use `none` mode.
              - layers with `partial_switch` off: use `flatten` mode.
            - pad / tile: Refer to the definition of `InflatedCausalConv3d`.
        direction:
            - empty string: Ordinary causal convolution layer.
            - out / in: Refer to the definition of `InflatedDistributionCausalConv3d`.
        partial_switch: Only takes effect when `inflation_mode` is `partial_flatten`.
    """
    stride = kwargs.get("stride", args[3] if len(args) > 3 else None)
    padding = kwargs.get("padding", args[4] if len(args) > 4 else None)
    if "flatten" in inflation_mode:
        if (
            (
                (not stride)
                or isinstance(stride, int)
                or (isinstance(stride, (list, tuple)) and len(stride) < 3)
            )  # if the config of stride can be used for 2D conv
            and (
                (not padding)
                or isinstance(padding, int)
                or (isinstance(padding, (list, tuple)) and len(padding) < 3)
            )  # if the config of padding can be used for 2D conv
            and (("partial" not in inflation_mode) or (not partial_switch))
            # if it's the fully-flatten mode, or `partial_switch` is off
        ):
            return FlattenedConv3d(*args, **kwargs)
        else:
            # Force-override
            return InflatedCausalConv3d(*args, inflation_mode="none", **kwargs)
    else:
        if direction:
            return InflatedDistributionCausalConv3d(
                *args, direction=direction, inflation_mode=inflation_mode, **kwargs
            )
        else:
            return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs)


def init_transposed_conv3d(
    *args, inflation_mode: _inflation_mode_t, partial_switch: bool = False, **kwargs
):
    stride = kwargs.get("stride", args[3] if len(args) > 3 else None)
    padding = kwargs.get("padding", args[4] if len(args) > 4 else None)
    if "flatten" in inflation_mode:
        if (
            (
                (not stride)
                or isinstance(stride, int)
                or (isinstance(stride, (list, tuple)) and len(stride) < 3)
            )
            and (
                (not padding)
                or isinstance(padding, int)
                or (isinstance(padding, (list, tuple)) and len(padding) < 3)
            )
            or (("partial" in inflation_mode) and not partial_switch)
        ):
            return FlattenedConvTranspose3d(*args, **kwargs)
        else:
            return InflatedConvTranspose3d(
                *args, inflation_mode="none", **kwargs
|
| 303 |
+
) # Force-override
|
| 304 |
+
else:
|
| 305 |
+
return InflatedConvTranspose3d(*args, inflation_mode=inflation_mode, **kwargs)
|
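Since `init_causal_conv3d` only picks a class, the mapping from `inflation_mode` to layer type can be checked directly. This is a minimal sketch, assuming the module path `adv_grpo.inflated_layers` resolves; the channel/kernel arguments are made up for illustration:

# Sketch only: exercises the dispatch logic of init_causal_conv3d.
# The import path and all argument values below are assumptions, not repo config.
from adv_grpo.inflated_layers import init_causal_conv3d

# 2D-compatible stride/padding + "flatten" -> FlattenedConv3d (2D conv under the hood)
conv_a = init_causal_conv3d(64, 64, 3, inflation_mode="flatten", stride=1, padding=1)
# "pad" (or "tile") -> InflatedCausalConv3d, which inflates 2D checkpoints on load
conv_b = init_causal_conv3d(64, 64, 3, inflation_mode="pad")
# non-empty direction -> InflatedDistributionCausalConv3d (quant_conv-style layers)
conv_c = init_causal_conv3d(8, 16, 1, inflation_mode="tile", direction="out")

print(type(conv_a).__name__, type(conv_b).__name__, type(conv_c).__name__)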
adv_grpo/inflated_lib.py
ADDED
@@ -0,0 +1,346 @@
import math
from enum import Enum
from typing import Optional

import numpy as np
import torch
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.normalization import RMSNorm
from einops import rearrange
from torch import Tensor, nn

# from common.logger import get_logger

# logger = get_logger(__name__)


class MemoryState(Enum):
    """
    State[Disabled]: No memory bank will be enabled.
    State[Initializing]: The model is handling the first clip
        and needs to reset / initialize the memory bank.
    State[Active]: There is already data in the memory bank.
    """

    DISABLED = 0
    INITIALIZING = 1
    ACTIVE = 2


def norm_wrapper(
    norm_layer: nn.Module,
    x: torch.Tensor,
    y: Optional[torch.Tensor] = None,
    keep_causal: bool = False,
) -> torch.Tensor:
    if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)):
        if x.ndim == 4:
            x = rearrange(x, "b c h w -> b h w c")
            x = norm_layer(x)
            x = rearrange(x, "b h w c -> b c h w")
            return x
        if x.ndim == 5:
            x = rearrange(x, "b c t h w -> b t h w c")
            x = norm_layer(x)
            x = rearrange(x, "b t h w c -> b c t h w")
            return x
    if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
        if x.ndim <= 4 or (not keep_causal and not isinstance(norm_layer, nn.BatchNorm2d)):
            return norm_layer(x)
        if x.ndim == 5:
            t = x.size(2)
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = norm_layer(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
            return x
    if isinstance(norm_layer, SpatialNorm):
        t = -1
        if x.ndim == 5:
            t = x.size(2)
            x = rearrange(x, "b c t h w -> (b t) c h w")
        if y.ndim == 5:
            y = rearrange(y, "b c t h w -> (b t) c h w")
        if x.ndim != 4 or y.ndim != 4:
            raise NotImplementedError
        x = norm_layer(x, y)
        if t != -1:
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x
    raise NotImplementedError


def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
    """
    Remove duplicated first-frame features in the up-sampling process.
    """
    if times == 0:
        return tensor
    return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)


def extend_head(
    tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None
) -> Tensor:
    """
    When memory is None:
        - Duplicate first-frame features in the down-sampling process.
    When memory is not None:
        - Concatenate memory features with the input features to keep temporal consistency.
    """
    if times == 0:
        return tensor
    if memory is not None:
        return torch.cat((memory.to(tensor), tensor), dim=2)
    else:
        tile_repeat = np.ones(tensor.ndim).astype(int)
        tile_repeat[2] = times
        return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2)


def fill_weight_in_depth(weight: torch.Tensor, source: torch.Tensor, position: str):
    """
    Inflate a 2D convolution weight matrix to a 3D one by padding zeros along the depth axis.
    Parameters:
        weight: The weight parameters of the 3D conv kernel to be initialized.
        source: The weight parameters of the 2D conv kernel to be inflated.
        position: Where to insert the 2D weights; one of
            - tail: Pad zeros in front of the 2D kernel. Used for causal inflation.
            - center: Pad zeros around the 2D kernel. Used for normal inflation.
    """
    assert position in ["tail", "center"], "Unsupported fill-in position for weight inflation."
    depth = weight.size(2)
    weight.fill_(0.0)
    if position == "center":
        if depth % 2 == 1:
            weight[:, :, depth // 2].copy_(source.squeeze(2))
        else:
            weight[:, :, depth // 2].copy_(source.squeeze(2) / 2.0)
            weight[:, :, depth // 2 - 1].copy_(source.squeeze(2) / 2.0)
    else:
        if depth % 2 == 1:
            weight[:, :, -1].copy_(source.squeeze(2))
        else:
            weight[:, :, -1].copy_(source.squeeze(2) / 2.0)
            weight[:, :, -2].copy_(source.squeeze(2) / 2.0)
    return weight


def inflate_weight(
    weight_2d: torch.Tensor,
    weight_3d: torch.Tensor,
    shape_norm: bool,
    name: str,
    inflation_mode: str,
    position: str,
    verbose: bool = True,
):
    """
    Inflate a 2D convolution weight matrix to a 3D one.
    Parameters:
        weight_2d: The weight matrix of the 2D conv to be inflated.
        weight_3d: The weight matrix of the 3D conv to be initialized.
        inflation_mode: The mode of inflation.
            - pad: pad zeros around the 2D kernel.
            - tile: tile the 2D kernel along the depth axis.

        shape_norm: Whether to scale the 2D kernel parameters so that the untrained
            inflated model behaves exactly the same as the original 2D model
            when reconstructing images and videos. Recommended to switch it on.

        name: The name of the inflated module. Only used in logging.
        position: Refer to the doc of `fill_weight_in_depth`.
            Only takes effect when `inflation_mode` is `pad`.
        verbose: Whether to log information about inflation.
    """
    assert inflation_mode in ["pad", "tile"]
    depth = weight_3d.size(2)
    tgt_out, tgt_in = weight_3d.size()[:2]
    src_out, src_in = weight_2d.size()[:2]
    assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0)
    out_fan, in_fan = tgt_out // src_out, tgt_in // src_in
    depth_factor = 1 if inflation_mode == "pad" else depth
    factor = (depth_factor * math.sqrt(out_fan) * math.sqrt(in_fan)) if shape_norm else 1
    with torch.no_grad():
        channel_inflation = weight_2d.unsqueeze(2).repeat(out_fan, in_fan, 1, 1, 1) / factor
        if inflation_mode == "tile":
            weight_3d.copy_(channel_inflation.repeat(1, 1, depth, 1, 1))
        else:
            weight_3d = fill_weight_in_depth(weight_3d, channel_inflation, position)
    if verbose:
        print(
            f"*** {name}weight {weight_2d.size()} is inflated to {weight_3d.size()} ***"
        )
    return weight_3d


def inflate_bias(
    bias_2d: torch.Tensor,
    bias_3d: torch.Tensor,
    shape_norm: bool,
    name: str,
    inflation_mode: str,
    position: str,
    verbose: bool = True,
):
    """
    Inflate a 2D convolution bias tensor to a 3D one.
    Parameters:
        bias_2d: The bias tensor of the 2D conv to be inflated.
        bias_3d: The bias tensor of the 3D conv to be initialized.
        shape_norm: Refer to the `inflate_weight` function.
        name: The name of the inflated module. Only used in logging.
        inflation_mode: Placeholder to align with `inflate_weight`.
        position: Placeholder to align with `inflate_weight`.
        verbose: Whether to log information about inflation.
    """
    tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0)
    assert tgt_ch % src_ch == 0
    fan = tgt_ch // src_ch
    factor = math.sqrt(fan) if shape_norm else 1
    with torch.no_grad():
        bias_3d.copy_(bias_2d.repeat(fan) / factor)
    if (tgt_ch != src_ch) and verbose:
        print(f"*** {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***")
    return bias_3d


def inflate_distribution_weight(
    weight_2d: torch.Tensor,
    weight_3d: torch.Tensor,
    shape_norm: bool,
    name: str,
    direction: str,
    inflation_mode: str,
    position: str,
    verbose: bool = True,
):
    """
    Inflate a 2D convolution weight matrix to a 3D one.
    Note: Unlike `inflate_weight`, this is designed for `quant_conv` / `post_quant_conv`
    layers, i.e., a convolution layer that produces the `mean` and `std` of some
    distribution, or its subsequent layer.
    Parameters: Refer to `inflate_weight`.
        direction:
            - out: this layer generates the `mean` and `std` of some distribution.
            - in: this layer takes tensors sampled from the output of an `out` layer as input.
    """
    assert inflation_mode in ["pad", "tile"]
    depth = weight_3d.size(2)
    tgt_out, tgt_in = weight_3d.size()[:2]
    src_out, src_in = weight_2d.size()[:2]
    assert (tgt_out % src_out == 0) and (tgt_in % src_in == 0)
    out_fan, in_fan = tgt_out // src_out, tgt_in // src_in
    depth_factor = 1 if inflation_mode == "pad" else depth
    if direction == "out":
        factor = (depth_factor * math.sqrt(in_fan)) if shape_norm else 1
        with torch.no_grad():
            in_inflation = weight_2d.unsqueeze(2).repeat(1, in_fan, 1, 1, 1) / factor
            # [src_out, src_in, k_h, k_w] -> [src_out, tgt_in, 1, k_h, k_w]
            out_mean_weight, out_std_weight = torch.chunk(in_inflation, 2, dim=0)
            mean_slice = slice(src_out // 2)
            std_slice = slice(tgt_out // 2, tgt_out // 2 + src_out // 2)
            if inflation_mode == "tile":
                weight_3d[mean_slice] = out_mean_weight
                weight_3d[std_slice] = out_std_weight
                # The remaining part is left randomly initialized.
            else:
                weight_3d[mean_slice] = fill_weight_in_depth(
                    weight_3d[mean_slice], out_mean_weight, position
                )
                weight_3d[std_slice] = fill_weight_in_depth(
                    weight_3d[std_slice], out_std_weight, position
                )
                # The remaining part is left randomly initialized.
    elif direction == "in":
        factor = (depth_factor * math.sqrt(out_fan)) if shape_norm else 1
        with torch.no_grad():
            out_inflation = weight_2d.unsqueeze(2).repeat(out_fan, 1, 1, 1, 1) / factor
            # [src_out, src_in, k_h, k_w] -> [tgt_out, src_in, 1, k_h, k_w]
            if inflation_mode == "tile":
                weight_3d[:, :src_in] = out_inflation
            else:
                weight_3d[:, :src_in] = fill_weight_in_depth(
                    weight_3d[:, :src_in], out_inflation, position
                )
            weight_3d[:, src_in:].fill_(0.0)
    else:
        raise NotImplementedError
    if verbose:
        print(
            f"*** [Distribution] {name}weight {weight_2d.size()} "
            f"is inflated to {weight_3d.size()} ***"
        )
    return weight_3d


def inflate_distribution_bias(
    bias_2d: torch.Tensor,
    bias_3d: torch.Tensor,
    shape_norm: bool,
    name: str,
    direction: str,
    inflation_mode: str,
    position: str,
    verbose: bool = True,
):
    """
    The combination of `inflate_distribution_weight` and `inflate_bias`.
    """
    tgt_ch, src_ch = bias_3d.size(0), bias_2d.size(0)
    assert tgt_ch % src_ch == 0
    if direction == "out":
        with torch.no_grad():
            out_mean_bias, out_std_bias = torch.chunk(bias_2d, 2, dim=0)
            bias_3d[: src_ch // 2] = out_mean_bias
            bias_3d[tgt_ch // 2 : tgt_ch // 2 + src_ch // 2] = out_std_bias
    elif direction == "in":
        with torch.no_grad():
            bias_3d[:src_ch] = bias_2d
            bias_3d[src_ch:].fill_(0.0)
    else:
        raise NotImplementedError
    if verbose:
        print(
            f"*** [Distribution] {name}bias {bias_2d.size()} is inflated to {bias_3d.size()} ***"
        )
    return bias_3d


def modify_state_dict(
    layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn, verbose=False
):
    """
    The main entry point for inflating 2D parameters to 3D.
    """
    weight_name = prefix + "weight"
    bias_name = prefix + "bias"
    if weight_name in state_dict:
        weight_2d = state_dict[weight_name]
        if (
            weight_2d.dim() == 4
        ):  # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
            weight_3d = inflate_weight_fn(
                weight_2d=weight_2d,
                weight_3d=layer.weight,
                shape_norm=layer.shape_norm,
                name=prefix,
                verbose=verbose,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[weight_name] = weight_3d
        else:
            # It's already a 3D state dict; do not inflate either the weight or the bias.
            return state_dict
    if bias_name in state_dict:
        bias_2d = state_dict[bias_name]
        if bias_2d.dim() == 1:  # Assuming the 2D biases are 1D tensors (out_channels,)
            bias_3d = inflate_bias_fn(
                bias_2d=bias_2d,
                bias_3d=layer.bias,
                shape_norm=layer.shape_norm,
                name=prefix,
                verbose=verbose,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[bias_name] = bias_3d
    return state_dict
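The `shape_norm` claim in `inflate_weight` (an untrained inflated model matches the 2D model on static content) can be checked without any repo code. A minimal sketch in plain torch, assuming `tile` mode with no channel fan-out: tiling the 2D kernel over depth and dividing by the depth makes a clip of identical frames reproduce the 2D output.

import torch
import torch.nn as nn

conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=False)
conv3d = nn.Conv3d(4, 8, kernel_size=(3, 3, 3), padding=(0, 1, 1), bias=False)

depth = conv3d.weight.size(2)
with torch.no_grad():
    # "tile" inflation with shape_norm: repeat along depth, divide by depth
    conv3d.weight.copy_(conv2d.weight.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)

img = torch.randn(1, 4, 16, 16)
clip = img.unsqueeze(2).repeat(1, 1, depth, 1, 1)  # `depth` identical frames

out_2d = conv2d(img)
out_3d = conv3d(clip).squeeze(2)  # temporal dim collapses to 1 (no time padding)
print(torch.allclose(out_2d, out_3d, atol=1e-5))  # True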
adv_grpo/ocr.py
ADDED
@@ -0,0 +1,138 @@
from typing import List, Union

import numpy as np
import torch
from Levenshtein import distance
from paddleocr import PaddleOCR
from PIL import Image


class OcrScorer:
    def __init__(self, use_gpu: bool = False):
        """
        OCR reward calculator.
        :param use_gpu: Whether to use GPU acceleration for PaddleOCR
        """
        self.ocr = PaddleOCR(
            use_angle_cls=False,
            lang="en",
            use_gpu=use_gpu,
            show_log=False,  # Disable unnecessary log output
        )

    @torch.no_grad()
    def __call__(self,
                 images: Union[List[Image.Image], List[np.ndarray]],
                 prompts: List[str]) -> List[float]:
        """
        Calculate the OCR reward.
        :param images: List of input images (PIL or numpy format)
        :param prompts: Corresponding list of target texts
        :return: List of per-image rewards
        """
        prompts = [prompt.split('"')[1] for prompt in prompts]
        rewards = []
        # Ensure input lengths are consistent
        assert len(images) == len(prompts), "Images and prompts must have the same length"
        for img, prompt in zip(images, prompts):
            # Convert image format
            if isinstance(img, Image.Image):
                img = np.array(img)

            try:
                # OCR recognition
                result = self.ocr.ocr(img, cls=False)
                # Extract recognized text (handle possible multi-line results)
                recognized_text = ''.join(
                    [res[1][0] if res[1][1] > 0 else '' for res in result[0]]
                ) if result[0] else ''

                recognized_text = recognized_text.replace(' ', '').lower()
                prompt = prompt.replace(' ', '').lower()
                if prompt in recognized_text:
                    dist = 0
                else:
                    dist = distance(recognized_text, prompt)
                # If many unrelated characters were recognized, cap the penalty
                # at one full target length
                if dist > len(prompt):
                    dist = len(prompt)

            except Exception as e:
                # Error handling (e.g., OCR parsing failure)
                print(f"OCR processing failed: {str(e)}")
                dist = len(prompt)  # Maximum penalty
            reward = 1 - dist / len(prompt)
            rewards.append(reward)

        return rewards


class OcrScorer_video_or_image:
    def __init__(self, use_gpu: bool = False):
        """
        OCR reward calculator.
        :param use_gpu: Whether to use GPU acceleration for PaddleOCR
        """
        self.ocr = PaddleOCR(
            use_angle_cls=False,
            lang="en",
            use_gpu=use_gpu,
            show_log=False,  # Disable unnecessary log output
        )
        self.frame_interval = 4

    @torch.no_grad()
    def __call__(self, images: Union[List[Image.Image], List[np.ndarray]], prompts: List[str]) -> List[float]:
        """
        :param images: List of images or videos (each video as np.ndarray of shape [F, H, W, C])
        :param prompts: List of prompts containing the target text
        :return: List of OCR rewards
        """
        prompts = [prompt.split('"')[1] for prompt in prompts]
        assert len(images) == len(prompts), "Mismatch between images and prompts."

        rewards = []
        for img, prompt in zip(images, prompts):
            prompt = prompt.replace(' ', '').lower()
            frame_rewards = []

            # Handle video: shape (F, H, W, C)
            if isinstance(img, np.ndarray) and img.ndim == 4:
                sampled_frames = img[::self.frame_interval]
            else:
                sampled_frames = [img]

            for frame in sampled_frames:
                if isinstance(frame, Image.Image):
                    frame = np.array(frame)
                try:
                    result = self.ocr.ocr(frame, cls=False)
                    text = ''.join(
                        [res[1][0] if res[1][1] > 0 else '' for res in result[0]]
                    ) if result[0] else ''
                    text = text.replace(' ', '').lower()

                    dist = distance(text, prompt)
                    dist = min(dist, len(prompt))

                except Exception as e:
                    print(f"OCR failed on frame: {e}")
                    dist = len(prompt)

                reward = 1 - dist / len(prompt)
                if reward > 0:
                    frame_rewards.append(reward)

            if frame_rewards:
                rewards.append(sum(frame_rewards) / len(frame_rewards))
            else:
                rewards.append(0.0)

        return rewards


if __name__ == "__main__":
    example_image_path = "media_images_eval_images_499_ef42de47b8ec98892954.jpg"
    example_image = Image.open(example_image_path)
    example_prompt = 'New York Skyline with "Hello World" written with fireworks on the sky'
    # Instantiate the scorer
    scorer = OcrScorer(use_gpu=False)

    # Call the scorer and print the result
    reward = scorer([example_image], [example_prompt])
    print(f"OCR Reward: {reward}")
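The reward in both scorers is 1 - min(edit_distance, len(target)) / len(target), clamped so that noisy extra characters cost at most the full target length. A tiny worked example with the same formula on toy strings:

from Levenshtein import distance

target = "helloworld"
recognized = "helloworid"  # one substituted character
dist = min(distance(recognized, target), len(target))
print(1 - dist / len(target))  # 0.9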
adv_grpo/pick_score_training.py
ADDED
@@ -0,0 +1,385 @@
import json
import os
from dataclasses import dataclass

import torch
import torch.distributed as dist
import torch.distributed.nn  # needed for torch.distributed.nn.all_gather
from PIL import Image
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers import CLIPModel, CLIPProcessor


def evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device, max_eval=100):
    """
    Simple evaluation: take the first max_eval Qwen-vs-SD3 pairs and compute average scores.
    """
    model.eval()
    if hasattr(model, "module"):  # DDP case
        model = model.module

    with open(json_file, "r") as f:
        prompt2img = json.load(f)

    prompts = list(prompt2img.keys())[:max_eval]

    qwen_scores, sd3_scores = [], []

    for prompt in prompts:
        filename = prompt2img[prompt]
        qwen_img_path = os.path.join(qwen_dir, filename)
        sd3_img_path = os.path.join(sd3_dir, filename)

        if not (os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path)):
            continue

        qwen_img = Image.open(qwen_img_path).convert("RGB")
        sd3_img = Image.open(sd3_img_path).convert("RGB")

        # Text & image inputs
        text_inputs = processor.tokenizer(
            prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77
        ).to(device)
        qwen_inputs = processor(images=qwen_img, return_tensors="pt").to(device)
        sd3_inputs = processor(images=sd3_img, return_tensors="pt").to(device)

        with torch.no_grad():
            text_features = model.get_text_features(**text_inputs)
            qwen_features = model.get_image_features(**qwen_inputs)
            sd3_features = model.get_image_features(**sd3_inputs)

        # Normalize
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        qwen_features = qwen_features / qwen_features.norm(dim=-1, keepdim=True)
        sd3_features = sd3_features / sd3_features.norm(dim=-1, keepdim=True)

        # Similarity scores
        logit_scale = model.logit_scale.exp()
        qwen_score = (logit_scale * (text_features @ qwen_features.T)).item()
        sd3_score = (logit_scale * (text_features @ sd3_features.T)).item()

        qwen_scores.append(qwen_score)
        sd3_scores.append(sd3_score)

    model.train()
    if len(qwen_scores) > 0:
        print(f"[Eval] Qwen avg={sum(qwen_scores)/len(qwen_scores):.4f} "
              f"| SD3 avg={sum(sd3_scores)/len(sd3_scores):.4f}")


# ====== Using the CLIPCriterion found earlier ======
@dataclass
class CLIPCriterionConfig:
    _target_: str = "trainer.criterions.clip_criterion.CLIPCriterion"
    is_distributed: bool = False  # keep off for local runs
    label_0_column_name: str = "label_0"
    label_1_column_name: str = "label_1"
    input_ids_column_name: str = "input_ids"
    pixels_0_column_name: str = "pixels_0"
    pixels_1_column_name: str = "pixels_1"
    num_examples_per_prompt_column_name: str = "num_examples_per_prompt"
    in_batch_negatives: bool = False


class CLIPCriterion(_Loss):
    def __init__(self, cfg: CLIPCriterionConfig):
        super().__init__()
        self.cfg = cfg

    @staticmethod
    def get_features(model, input_ids, pixels_0_values, pixels_1_values):
        all_pixel_values = torch.cat([pixels_0_values, pixels_1_values], dim=0)
        text_features = model.get_text_features(input_ids=input_ids)
        all_image_features = model.get_image_features(pixel_values=all_pixel_values)
        all_image_features = all_image_features / all_image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_0_features, image_1_features = all_image_features.chunk(2, dim=0)
        return image_0_features, image_1_features, text_features

    @staticmethod
    def gather_features(features):
        all_features = torch.cat(torch.distributed.nn.all_gather(features), dim=0)
        return all_features

    def calc_loss(
        self,
        text_features,
        image_0_features,
        image_1_features,
        logit_scale,
        label_0,
        label_1,
        num_examples_per_prompt,
        *args,
        **kwargs
    ):
        device = image_0_features.device

        # gather features
        if self.cfg.is_distributed:
            image_0_features = self.gather_features(image_0_features)
            image_1_features = self.gather_features(image_1_features)
            text_features = self.gather_features(text_features)
            label_0 = self.gather_features(label_0)
            label_1 = self.gather_features(label_1)
            num_examples_per_prompt = self.gather_features(num_examples_per_prompt)

        # calc logits  # TODO: use local loss as open-clip does
        all_image_features = torch.cat([image_0_features, image_1_features], dim=0)  # (2 * batch_size, dim)
        logits_per_image = logit_scale * all_image_features @ text_features.T
        image_0_logits, image_1_logits = logits_per_image.chunk(2, dim=0)
        text_logits = logit_scale * text_features @ all_image_features.T

        if self.cfg.in_batch_negatives:
            # get labels
            num_images = all_image_features.shape[0]
            image_labels = torch.arange(num_images, device=device, dtype=torch.long)
            image_0_labels, image_1_labels = image_labels.chunk(2, dim=0)
            num_texts = text_features.shape[0]
            text_labels = torch.arange(num_texts, device=device, dtype=torch.long)

            # image loss - we want to increase the logits of the preferred image for the text
            image_0_loss = torch.nn.functional.cross_entropy(image_0_logits, text_labels, reduction="none")
            image_1_loss = torch.nn.functional.cross_entropy(image_1_logits, text_labels, reduction="none")
            # if we have a tie, we increase both images equally and average, so the image
            # loss of each example stays proportional
            image_loss = label_0 * image_0_loss + label_1 * image_1_loss

            # text loss - we want to increase the logits of the text for the preferred image
            text_0_loss = torch.nn.functional.cross_entropy(text_logits, image_0_labels, reduction="none")
            text_1_loss = torch.nn.functional.cross_entropy(text_logits, image_1_labels, reduction="none")

        else:
            text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
            index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)

            text_0_logits = text_0_logits[index, index]
            text_1_logits = text_1_logits[index, index]
            text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
            text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
            text_1_labels = text_0_labels + 1
            text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
            text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")

        # if we have a tie, we want the logits of each image to be equal
        text_loss = label_0 * text_0_loss + label_1 * text_1_loss
        # we want the ideal loss to be 0; currently, if there is a tie, it is
        # 0.5 * log(0.5) + 0.5 * log(0.5), so we add log(0.5) to the loss
        is_tie = (label_0 == label_1).float()
        is_tie *= torch.log(torch.tensor(0.5, device=device))
        text_loss += is_tie

        # we average the image and text loss
        if self.cfg.in_batch_negatives:
            loss = (image_loss + text_loss) / 2
        else:
            loss = text_loss

        # some prompts have lots of interactions; they could be weighted accordingly
        # absolute_example_weight = 1 / num_examples_per_prompt
        # denominator = absolute_example_weight.sum()
        # weight_per_example = absolute_example_weight / denominator
        # loss *= weight_per_example
        loss = loss.mean()
        return loss

    def forward(self, model, batch):
        image_0_features, image_1_features, text_features = self.get_features(
            model,
            batch[self.cfg.input_ids_column_name],
            batch[self.cfg.pixels_0_column_name],
            batch[self.cfg.pixels_1_column_name]
        )

        loss = self.calc_loss(
            text_features,
            image_0_features,
            image_1_features,
            model.logit_scale.exp(),
            batch[self.cfg.label_0_column_name],
            batch[self.cfg.label_1_column_name],
            batch[self.cfg.num_examples_per_prompt_column_name],
        )
        return loss


# ====== Data preparation ======
class QwenSD3JsonDataset(Dataset):
    def __init__(self, processor, json_file, qwen_dir, sd3_dir):
        """
        json_file: prompt2img.json {prompt: filename}
        qwen_dir: directory holding the Qwen images
        sd3_dir: directory holding the SD3 images
        """
        self.processor = processor

        with open(json_file, "r") as f:
            self.prompt2img = json.load(f)

        self.prompts = list(self.prompt2img.keys())
        self.qwen_dir = qwen_dir
        self.sd3_dir = sd3_dir

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        filename = self.prompt2img[prompt]

        qwen_img_path = os.path.join(self.qwen_dir, filename)
        sd3_img_path = os.path.join(self.sd3_dir, filename)

        if os.path.exists(qwen_img_path) and os.path.exists(sd3_img_path):
            qwen_img = Image.open(qwen_img_path).convert("RGB")
            sd3_img = Image.open(sd3_img_path).convert("RGB")
        else:
            # fall back to the SD3 image for both sides if the Qwen image is missing
            qwen_img = Image.open(sd3_img_path).convert("RGB")
            sd3_img = Image.open(sd3_img_path).convert("RGB")

        # Text tokens
        text_inputs = self.processor.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )
        input_ids = text_inputs["input_ids"].squeeze(0)

        # Image preprocessing
        pixels_0 = self.processor(images=qwen_img, return_tensors="pt")["pixel_values"].squeeze(0)
        pixels_1 = self.processor(images=sd3_img, return_tensors="pt")["pixel_values"].squeeze(0)

        return {
            "input_ids": input_ids,
            "pixels_0": pixels_0,  # positive sample (Qwen)
            "pixels_1": pixels_1,  # negative sample (SD3)
            "label_0": torch.tensor(1.0),
            "label_1": torch.tensor(0.0),
            "num_examples_per_prompt": torch.tensor(1.0)
        }


# ====== Training loop ======
# Single-GPU variant, kept for reference:
# def finetune_pickscore(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6, device="cuda"):
#     processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
#     model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
#
#     dataset = QwenSD3JsonDataset(processor, json_file, qwen_dir, sd3_dir)
#     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
#
#     criterion = CLIPCriterion(CLIPCriterionConfig())
#     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
#
#     model.train()
#     for epoch in range(epochs):
#         total_loss = 0.0
#         for batch in dataloader:
#             batch = {k: v.to(device) for k, v in batch.items()}
#             loss = criterion(model, batch)
#
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#
#             total_loss += loss.item()
#         print(f"Epoch {epoch} | Loss {total_loss/len(dataloader):.4f}")
#
#     model.save_pretrained("pickscore_qwen_finetuned")
#     return model

def finetune_pickscore_distributed(json_file, qwen_dir, sd3_dir, epochs=2, batch_size=4, lr=1e-6):
    # 1. Initialize distributed training
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # 2. Prepare the data
    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    dataset = QwenSD3JsonDataset(processor, json_file, qwen_dir, sd3_dir)
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # 3. Model + DDP
    model = CLIPModel.from_pretrained("yuvalkirstain/PickScore_v1").to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    criterion = CLIPCriterion(CLIPCriterionConfig())
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # 4. Train
    model.train()
    if dist.get_rank() == 0:
        evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # reshuffle the sampler each epoch
        total_loss = 0.0

        for step, batch in enumerate(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = criterion(model.module, batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the loss (locally first)
            total_loss += loss.item()

            # Print every few steps (rank 0 only)
            if step % 50 == 0:  # adjust to 10 or 100 as needed
                # all_reduce averages the loss across all GPUs
                avg_loss = torch.tensor(loss.item(), device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
                if dist.get_rank() == 0:
                    print(f"[Epoch {epoch} | Step {step}/{len(dataloader)}] "
                          f"local_loss={loss.item():.4f} | avg_loss={avg_loss.item():.4f}")

        # Print the average loss at the end of each epoch
        epoch_loss = torch.tensor(total_loss / len(dataloader), device=device)
        dist.all_reduce(epoch_loss, op=dist.ReduceOp.AVG)
        if dist.get_rank() == 0:
            print(f"===> Epoch {epoch} done | avg_epoch_loss={epoch_loss.item():.4f}")
            evaluate_pickscore(model, processor, json_file, qwen_dir, sd3_dir, device)

    # 5. Save the model (rank 0 only)
    if dist.get_rank() == 0:
        model.module.save_pretrained("pickscore_qwen_finetuned")

    dist.destroy_process_group()


# ====== Usage example ======
if __name__ == "__main__":
    finetune_pickscore_distributed(
        json_file="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images/prompt2img.json",
        qwen_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/qwen_images",
        sd3_dir="/mnt/bn/vgfm2/test_dit/weijia/outputs/sd3_images",
        epochs=2,
        batch_size=4,
        lr=1e-6,
    )
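The non-`in_batch_negatives` branch of `calc_loss` reduces, per prompt, to a two-way cross-entropy over the logits of the two candidate images. A minimal sketch with toy numbers (shapes and values assumed, not repo data):

import torch
import torch.nn.functional as F

logits = torch.tensor([[21.3, 19.8]])  # [text->image_0, text->image_1] for one prompt
label_0, label_1 = torch.tensor([1.0]), torch.tensor([0.0])  # image_0 preferred

loss_0 = F.cross_entropy(logits, torch.tensor([0]), reduction="none")
loss_1 = F.cross_entropy(logits, torch.tensor([1]), reduction="none")
loss = label_0 * loss_0 + label_1 * loss_1
print(loss)  # small when the preferred image already wins by a margin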
adv_grpo/pickscore_scorer.py
ADDED
@@ -0,0 +1,70 @@
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class PickScoreScorer(torch.nn.Module):
    def __init__(self, device="cuda", dtype=torch.float32):
        super().__init__()
        processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        model_path = "yuvalkirstain/PickScore_v1"
        self.device = device
        self.dtype = dtype
        self.processor = CLIPProcessor.from_pretrained(processor_path)
        self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
        self.model = self.model.to(dtype=dtype)

    @torch.no_grad()
    def __call__(self, prompt, images):
        if hasattr(self.model, "module"):  # unwrap DDP if needed
            self.model = self.model.module
        # Preprocess images
        image_inputs = self.processor(
            images=images,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
        # Preprocess text
        text_inputs = self.processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}

        # Get embeddings
        image_embs = self.model.get_image_features(**image_inputs)
        image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)

        text_embs = self.model.get_text_features(**text_inputs)
        text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)

        # Calculate scores
        logit_scale = self.model.logit_scale.exp()
        scores = logit_scale * (text_embs @ image_embs.T)
        scores = scores.diag()
        # normalize roughly to 0-1
        scores = scores / 26
        return scores


# Usage example
def main():
    scorer = PickScoreScorer(
        device="cuda",
        dtype=torch.float32
    )
    images = [
        "nasa.jpg",
    ]
    pil_images = [Image.open(img) for img in images]
    prompts = [
        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    ]
    print(scorer(prompts, pil_images))


if __name__ == "__main__":
    main()
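The score path above is just a scaled cosine similarity paired along the diagonal. A toy sketch with random features (the /26 divisor mirrors the rough 0-1 normalization used above; it is an empirical constant, not a CLIP property):

import torch
import torch.nn.functional as F

text_embs = F.normalize(torch.randn(2, 8), dim=-1)
image_embs = F.normalize(torch.randn(2, 8), dim=-1)
logit_scale = torch.tensor(100.0)  # placeholder for model.logit_scale.exp()

scores = (logit_scale * text_embs @ image_embs.T).diag()  # pair i-th prompt with i-th image
print(scores / 26)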
adv_grpo/pickscore_scorer_constractive.py
ADDED
@@ -0,0 +1,89 @@
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class PickScoreScorerConstractive(torch.nn.Module):
    def __init__(self, device="cuda", dtype=torch.float32):
        super().__init__()
        processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        model_path = "yuvalkirstain/PickScore_v1"
        self.device = device
        self.dtype = dtype
        self.processor = CLIPProcessor.from_pretrained(processor_path)
        self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
        self.model = self.model.to(dtype=dtype)

    @torch.no_grad()
    def __call__(self, prompt, images, ref_images):
        # Preprocess images
        image_inputs = self.processor(
            images=images,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}

        ref_image_inputs = self.processor(
            images=ref_images,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        ref_image_inputs = {k: v.to(device=self.device) for k, v in ref_image_inputs.items()}

        # Preprocess text
        text_inputs = self.processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}

        # Get embeddings
        image_embs = self.model.get_image_features(**image_inputs)
        image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)

        ref_image_embs = self.model.get_image_features(**ref_image_inputs)
        ref_image_embs = ref_image_embs / ref_image_embs.norm(p=2, dim=-1, keepdim=True)

        text_embs = self.model.get_text_features(**text_inputs)
        text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)

        # Calculate scores
        logit_scale = self.model.logit_scale.exp()
        scores = logit_scale * (text_embs @ image_embs.T)
        scores = scores.diag()
        # normalize roughly to 0-1
        scores = scores / 26

        ref_scores = logit_scale * (text_embs @ ref_image_embs.T)
        ref_scores = ref_scores.diag()
        ref_scores = ref_scores / 26

        return scores, ref_scores, image_embs, ref_image_embs


# Usage example
def main():
    scorer = PickScoreScorerConstractive(
        device="cuda",
        dtype=torch.float32
    )
    images = [
        "nasa.jpg",
    ]
    pil_images = [Image.open(img) for img in images]
    prompts = [
        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    ]
    # the reference batch reuses the same images here, just to exercise the API
    print(scorer(prompts, pil_images, pil_images))


if __name__ == "__main__":
    main()
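Returning (scores, ref_scores, image_embs, ref_image_embs) leaves the contrastive combination to the caller. One hypothetical way to combine them (the weighting below is an assumption for illustration, not the repo's reward):

import torch
import torch.nn.functional as F

scores = torch.tensor([0.82, 0.74])      # sample vs. prompt
ref_scores = torch.tensor([0.78, 0.79])  # reference vs. prompt
image_embs = F.normalize(torch.randn(2, 8), dim=-1)
ref_image_embs = F.normalize(torch.randn(2, 8), dim=-1)

sim_to_ref = (image_embs * ref_image_embs).sum(dim=-1)  # cosine similarity
reward = (scores - ref_scores) - 0.1 * sim_to_ref       # hypothetical weighting
print(reward)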
adv_grpo/pickscore_scorer_patch.py
ADDED
@@ -0,0 +1,78 @@
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class PickScoreScorer(torch.nn.Module):
    def __init__(self, device="cuda", dtype=torch.float32):
        super().__init__()
        processor_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        model_path = "yuvalkirstain/PickScore_v1"
        self.device = device
        self.dtype = dtype
        self.processor = CLIPProcessor.from_pretrained(processor_path)
        self.model = CLIPModel.from_pretrained(model_path).eval().to(device)
        self.model = self.model.to(dtype=dtype)

    @torch.no_grad()
    def __call__(self, prompt, images):
        if hasattr(self.model, "module"):  # unwrap DDP if needed
            self.model = self.model.module
        # Preprocess images
        image_inputs = self.processor(
            images=images,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()}
        # Preprocess text
        text_inputs = self.processor(
            text=prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )
        text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()}

        # Get patch-level embeddings: run the vision tower directly and project
        # every token instead of only the pooled output
        # image_embs = self.model.get_image_features(**image_inputs)
        image_embs = self.model.vision_model(image_inputs["pixel_values"], output_hidden_states=True)
        image_embs = image_embs.last_hidden_state

        image_embs = self.model.visual_projection(image_embs)  # [B, N, 1024]
        image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True)

        text_embs = self.model.get_text_features(**text_inputs)
        text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True)

        # Calculate scores
        logit_scale = self.model.logit_scale.exp()
        patch_scores = torch.einsum("bd,bnd->bn", text_embs, image_embs)  # [B, N]
        scores = logit_scale * patch_scores.mean(dim=1)  # average over all patches
        # normalize roughly to 0-1
        scores = scores / 26
        return scores


# Usage example
def main():
    scorer = PickScoreScorer(
        device="cuda",
        dtype=torch.float32
    )
    images = [
        "nasa.jpg",
    ]
    pil_images = [Image.open(img) for img in images]
    prompts = [
        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    ]
    print(scorer(prompts, pil_images))


if __name__ == "__main__":
    main()
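The patch-level variant replaces the single pooled image embedding with one embedding per ViT token and averages the per-patch text similarities. A toy illustration of that einsum (shapes assumed for illustration):

import torch
import torch.nn.functional as F

text_embs = F.normalize(torch.randn(2, 16), dim=-1)      # [B, D]
patch_embs = F.normalize(torch.randn(2, 5, 16), dim=-1)  # [B, N, D] per-patch features
patch_scores = torch.einsum("bd,bnd->bn", text_embs, patch_embs)  # [B, N]
print(patch_scores.mean(dim=1))  # one averaged score per image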