## Training Details

Training details can be found here: https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx

## Source Code

Models are diffusion trainers from https://github.com/lucidrains/DALLE2-pytorch
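If you don't already have that library, it is published on PyPI under the name `dalle2-pytorch` (a working PyTorch install is assumed):

```shell
# install lucidrains' DALLE2-pytorch, which provides the classes used below
pip install dalle2-pytorch
```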
## Community: LAION

Join Us!: https://discord.gg/uPMftTmrvS

---

# Models

The repo currently has many models (most of which are actually pretty bad). I recommend using the latest EMA checkpoints for now.

> **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of the repo and massively underperform recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README commit timestamp).
### Loading the models might look something like this:

> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and having issues loading checkpoints, please reach out on LAION.

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer


def load_diffusion_model(dprior_path, device):
    # If you are getting issues with size mismatches, it's likely this configuration
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4,
    )

    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # this will load the entire trainer
    # If you only want EMA weights for inference you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that please make a PR on Github!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
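As the comments above mention, the checkpoint loads the entire trainer, and extracting just the EMA weights is left to the reader for now. A minimal sketch of one way to do it — pulling every entry under a given key prefix out of a state dict. The helper name and the `ema_model.` prefix are assumptions for illustration; inspect your checkpoint's keys (`state_dict.keys()`) to find the actual prefix before relying on this:

```python
def extract_prefixed_state_dict(state_dict, prefix):
    """Return a copy of the entries whose keys start with `prefix`,
    with the prefix stripped off.

    Hypothetical helper: the `ema_model.` prefix used below is an
    assumption -- check your checkpoint's actual key layout first.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# toy demonstration with a fake checkpoint-style dict
fake_ckpt = {
    "ema_model.net.weight": 1,
    "ema_model.net.bias": 2,
    "net.weight": 3,
}
ema_only = extract_prefixed_state_dict(fake_ckpt, "ema_model.")
# ema_only == {"net.weight": 1, "net.bias": 2}
```

The stripped dict can then be fed to `load_state_dict` on a freshly constructed `DiffusionPrior` for inference-only use.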