Spaces:
Sleeping
Sleeping
Commit ·
3fbcf45
0
Parent(s):
working
Browse files- .gitignore +2 -0
- README.md +0 -0
- app.py +0 -0
- assets/background.svg +23 -0
- assets/favicon.svg +5 -0
- assets/index.css +15 -0
- diffusers_lora_finetune.py +549 -0
- instance_example_urls.txt +3 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
.venv
|
README.md
ADDED
|
File without changes
|
app.py
ADDED
|
File without changes
|
assets/background.svg
ADDED
|
|
assets/favicon.svg
ADDED
|
|
assets/index.css
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Bit of Modal Labs color scheming for the Gradio.app UI
|
| 2 |
+
|
| 3 |
+
from https://github.com/modal-labs/modal-examples */
|
| 4 |
+
|
| 5 |
+
a {
|
| 6 |
+
text-decoration: inherit !important;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
gradio-app {
  background-image: url(/assets/background.svg) !important;
  background-repeat: no-repeat !important;
  /* fixed: original read `background-size 100% auto;` (missing colon),
     which browsers silently drop as an invalid declaration */
  background-size: 100% auto;
  padding-top: 3%;
  background-color: black;
}
|
diffusers_lora_finetune.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---
|
| 2 |
+
# deploy: true
|
| 3 |
+
# ---
|
| 4 |
+
|
| 5 |
+
# # Fine-tune Flux on your pet using LoRA
|
| 6 |
+
|
| 7 |
+
# This example finetunes the [Flux.1-dev model](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
| 8 |
+
# on images of a pet (by default, a puppy named Qwerty)
|
| 9 |
+
# using a technique called textual inversion from [the "Dreambooth" paper](https://dreambooth.github.io/).
|
| 10 |
+
# Effectively, it teaches a general image generation model a new "proper noun",
|
| 11 |
+
# allowing for the personalized generation of art and photos.
|
| 12 |
+
# We supplement textual inversion with low-rank adaptation (LoRA)
|
| 13 |
+
# for increased efficiency during training.
|
| 14 |
+
|
| 15 |
+
# It then makes the model shareable with others -- without costing $25/day for a GPU server--
|
| 16 |
+
# by hosting a [Gradio app](https://gradio.app/) on Modal.
|
| 17 |
+
|
| 18 |
+
# It demonstrates a simple, productive, and cost-effective pathway
|
| 19 |
+
# to building on large pretrained models using Modal's building blocks, like
|
| 20 |
+
# [GPU-accelerated](https://modal.com/docs/guide/gpu) Modal Functions and Classes for compute-intensive work,
|
| 21 |
+
# [Volumes](https://modal.com/docs/guide/volumes) for storage,
|
| 22 |
+
# and [web endpoints](https://modal.com/docs/guide/webhooks) for serving.
|
| 23 |
+
|
| 24 |
+
# And with some light customization, you can use it to generate images of your pet!
|
| 25 |
+
|
| 26 |
+
# 
|
| 27 |
+
|
| 28 |
+
# You can find a video walkthrough of this example on the Modal YouTube channel
|
| 29 |
+
# [here](https://www.youtube.com/watch?v=df-8fiByXMI).
|
| 30 |
+
|
| 31 |
+
# ## Imports and setup
|
| 32 |
+
|
| 33 |
+
# We start by importing the necessary libraries and setting up the environment.
|
| 34 |
+
|
| 35 |
+
from dataclasses import dataclass
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
|
| 38 |
+
import modal
|
| 39 |
+
|
| 40 |
+
# ## Building up the environment
|
| 41 |
+
|
| 42 |
+
# Machine learning environments are complex, and the dependencies can be hard to manage.
|
| 43 |
+
# Modal makes creating and working with environments easy via
|
| 44 |
+
# [containers and container images](https://modal.com/docs/guide/custom-container).
|
| 45 |
+
|
| 46 |
+
# We start from a base image and specify all of our dependencies.
|
| 47 |
+
# We'll call out the interesting ones as they come up below.
|
| 48 |
+
# Note that these dependencies are not installed locally
|
| 49 |
+
# -- they are only installed in the remote environment where our Modal App runs.
|
| 50 |
+
|
| 51 |
+
app = modal.App(name="jason-lora-flux")
|
| 52 |
+
|
| 53 |
+
image = modal.Image.debian_slim(python_version="3.10").pip_install(
|
| 54 |
+
"accelerate==0.31.0",
|
| 55 |
+
"datasets~=2.13.0",
|
| 56 |
+
"fastapi[standard]==0.115.4",
|
| 57 |
+
"ftfy~=6.1.0",
|
| 58 |
+
"gradio~=5.5.0",
|
| 59 |
+
"huggingface-hub==0.26.2",
|
| 60 |
+
"hf_transfer==0.1.8",
|
| 61 |
+
"numpy<2",
|
| 62 |
+
"peft==0.11.1",
|
| 63 |
+
"pydantic==2.9.2",
|
| 64 |
+
"sentencepiece>=0.1.91,!=0.1.92",
|
| 65 |
+
"smart_open~=6.4.0",
|
| 66 |
+
"starlette==0.41.2",
|
| 67 |
+
"transformers~=4.41.2",
|
| 68 |
+
"torch~=2.2.0",
|
| 69 |
+
"torchvision~=0.16",
|
| 70 |
+
"triton~=2.2.0",
|
| 71 |
+
"wandb==0.17.6",
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# ### Downloading scripts and installing a git repo with `run_commands`
|
| 75 |
+
|
| 76 |
+
# We'll use an example script from the `diffusers` library to train the model.
|
| 77 |
+
# We acquire it from GitHub and install it in our environment with a series of commands.
|
| 78 |
+
# The container environments Modal Functions run in are highly flexible --
|
| 79 |
+
# see [the docs](https://modal.com/docs/guide/custom-container) for more details.
|
| 80 |
+
|
| 81 |
+
GIT_SHA = "e649678bf55aeaa4b60bd1f68b1ee726278c0304" # specify the commit to fetch
|
| 82 |
+
|
| 83 |
+
image = (
|
| 84 |
+
image.apt_install("git")
|
| 85 |
+
# Perform a shallow fetch of just the target `diffusers` commit, checking out
|
| 86 |
+
# the commit in the container's home directory, /root. Then install `diffusers`
|
| 87 |
+
.run_commands(
|
| 88 |
+
"cd /root && git init .",
|
| 89 |
+
"cd /root && git remote add origin https://github.com/huggingface/diffusers",
|
| 90 |
+
f"cd /root && git fetch --depth=1 origin {GIT_SHA} && git checkout {GIT_SHA}",
|
| 91 |
+
"cd /root && pip install -e .",
|
| 92 |
+
)
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# ### Configuration with `dataclass`es
|
| 96 |
+
|
| 97 |
+
# Machine learning apps often have a lot of configuration information.
|
| 98 |
+
# We collect up all of our configuration into dataclasses to avoid scattering special/magic values throughout code.
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@dataclass
|
| 102 |
+
class SharedConfig:
|
| 103 |
+
"""Configuration information shared across project components."""
|
| 104 |
+
|
| 105 |
+
# The instance name is the "proper noun" we're teaching the model
|
| 106 |
+
instance_name: str = "Qwerty"
|
| 107 |
+
# That proper noun is usually a member of some class (person, bird),
|
| 108 |
+
# and sharing that information with the model helps it generalize better.
|
| 109 |
+
class_name: str = "Golden Retriever"
|
| 110 |
+
# identifier for pretrained models on Hugging Face
|
| 111 |
+
model_name: str = "black-forest-labs/FLUX.1-dev"
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ### Storing data created by our app with `modal.Volume`
|
| 115 |
+
|
| 116 |
+
# The tools we've used so far work well for fetching external information,
|
| 117 |
+
# which defines the environment our app runs in,
|
| 118 |
+
# but what about data that we create or modify during the app's execution?
|
| 119 |
+
# A persisted [`modal.Volume`](https://modal.com/docs/guide/volumes) can store and share data across Modal Apps and Functions.
|
| 120 |
+
|
| 121 |
+
# We'll use one to store both the original and fine-tuned weights we create during training
|
| 122 |
+
# and then load them back in for inference.
|
| 123 |
+
|
| 124 |
+
volume = modal.Volume.from_name(
|
| 125 |
+
"dreambooth-finetuning-volume-flux", create_if_missing=True
|
| 126 |
+
)
|
| 127 |
+
MODEL_DIR = "/model"
|
| 128 |
+
|
| 129 |
+
# Note that access to the Flux.1-dev model on Hugging Face is
|
| 130 |
+
# [gated by a license agreement](https://huggingface.co/docs/hub/en/models-gated) which
|
| 131 |
+
# you must agree to [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
|
| 132 |
+
# After you have accepted the license, [create a Modal Secret](https://modal.com/secrets)
|
| 133 |
+
# with the name `huggingface-secret` following the instructions in the template.
|
| 134 |
+
|
| 135 |
+
huggingface_secret = modal.Secret.from_name(
|
| 136 |
+
"huggingface-secret", required_keys=["HF_TOKEN"]
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
image = image.env(
|
| 140 |
+
{"HF_HUB_ENABLE_HF_TRANSFER": "1"} # turn on faster downloads from HF
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@app.function(
|
| 145 |
+
volumes={MODEL_DIR: volume},
|
| 146 |
+
image=image,
|
| 147 |
+
secrets=[huggingface_secret],
|
| 148 |
+
timeout=600, # 10 minutes
|
| 149 |
+
)
|
| 150 |
+
def download_models(config):
|
| 151 |
+
import torch
|
| 152 |
+
from diffusers import DiffusionPipeline
|
| 153 |
+
from huggingface_hub import snapshot_download
|
| 154 |
+
|
| 155 |
+
snapshot_download(
|
| 156 |
+
config.model_name,
|
| 157 |
+
local_dir=MODEL_DIR,
|
| 158 |
+
ignore_patterns=["*.pt", "*.bin"], # using safetensors
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
DiffusionPipeline.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ### Load fine-tuning dataset
|
| 165 |
+
|
| 166 |
+
# Part of the magic of the low-rank fine-tuning is that we only need 3-10 images for fine-tuning.
|
| 167 |
+
# So we can fetch just a few images, stored on consumer platforms like Imgur or Google Drive,
|
| 168 |
+
# whenever we need them -- no need for expensive, hard-to-maintain data pipelines.
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def load_images(image_urls: list[str]) -> Path:
|
| 172 |
+
import PIL.Image
|
| 173 |
+
from smart_open import open
|
| 174 |
+
|
| 175 |
+
img_path = Path("/img")
|
| 176 |
+
|
| 177 |
+
img_path.mkdir(parents=True, exist_ok=True)
|
| 178 |
+
for ii, url in enumerate(image_urls):
|
| 179 |
+
with open(url, "rb") as f:
|
| 180 |
+
image = PIL.Image.open(f)
|
| 181 |
+
image.save(img_path / f"{ii}.png")
|
| 182 |
+
print(f"{ii + 1} images loaded")
|
| 183 |
+
|
| 184 |
+
return img_path
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ## Low-Rank Adaptation (LoRA) fine-tuning for a text-to-image model
|
| 188 |
+
|
| 189 |
+
# The base model we start from is trained to do a sort of "reverse [ekphrasis](https://en.wikipedia.org/wiki/Ekphrasis)":
|
| 190 |
+
# it attempts to recreate a visual work of art or image from only its description.
|
| 191 |
+
|
| 192 |
+
# We can use the model to synthesize wholly new images
|
| 193 |
+
# by combining the concepts it has learned from the training data.
|
| 194 |
+
|
| 195 |
+
# We use a pretrained model, the Flux model from Black Forest Labs.
|
| 196 |
+
# In this example, we "finetune" Flux, making only small adjustments to the weights.
|
| 197 |
+
# Furthermore, we don't change all the weights in the model.
|
| 198 |
+
# Instead, using a technique called [_low-rank adaptation_](https://arxiv.org/abs/2106.09685),
|
| 199 |
+
# we change a much smaller matrix that works "alongside" the existing weights, nudging the model in the direction we want.
|
| 200 |
+
|
| 201 |
+
# We can get away with such a small and simple training process because we're just teaching the model the meaning of a single new word: the name of our pet.
|
| 202 |
+
|
| 203 |
+
# The result is a model that can generate novel images of our pet:
|
| 204 |
+
# as an astronaut in space, as painted by Van Gogh or Bastiat, etc.
|
| 205 |
+
|
| 206 |
+
# ### Finetuning with Hugging Face 🧨 Diffusers and Accelerate
|
| 207 |
+
|
| 208 |
+
# The model weights, training libraries, and training script are all provided by [🤗 Hugging Face](https://huggingface.co).
|
| 209 |
+
|
| 210 |
+
# You can kick off a training job with the command `modal run dreambooth_app.py::app.train`.
|
| 211 |
+
# It should take about ten minutes.
|
| 212 |
+
|
| 213 |
+
# Training machine learning models takes time and produces a lot of metadata --
|
| 214 |
+
# metrics for performance and resource utilization,
|
| 215 |
+
# metrics for model quality and training stability,
|
| 216 |
+
# and model inputs and outputs like images and text.
|
| 217 |
+
# This is especially important if you're fiddling around with the configuration parameters.
|
| 218 |
+
|
| 219 |
+
# This example can optionally use [Weights & Biases](https://wandb.ai) to track all of this training information.
|
| 220 |
+
# Just sign up for an account, switch the flag below, and add your API key as a [Modal Secret](https://modal.com/secrets).
|
| 221 |
+
|
| 222 |
+
USE_WANDB = False
|
| 223 |
+
|
| 224 |
+
# You can see an example W&B dashboard [here](https://wandb.ai/cfrye59/dreambooth-lora-sd-xl).
|
| 225 |
+
# Check out [this run](https://wandb.ai/cfrye59/dreambooth-lora-sd-xl/runs/ca3v1lsh?workspace=user-cfrye59),
|
| 226 |
+
# which [despite having high GPU utilization](https://wandb.ai/cfrye59/dreambooth-lora-sd-xl/runs/ca3v1lsh/system)
|
| 227 |
+
# suffered from numerical instability during training and produced only black images -- hard to debug without experiment management logs!
|
| 228 |
+
|
| 229 |
+
# You can read more about how the values in `TrainConfig` are chosen and adjusted [in this blog post on Hugging Face](https://huggingface.co/blog/dreambooth).
|
| 230 |
+
# To run training on images of your own pet, upload the images to separate URLs and edit the contents of the file at `TrainConfig.instance_example_urls_file` to point to them.
|
| 231 |
+
|
| 232 |
+
# Tip: if the results you're seeing don't match the prompt too well, and instead produce an image
|
| 233 |
+
# of your subject without taking the prompt into account, the model has likely overfit. In this case, repeat training with a lower
|
| 234 |
+
# value of `max_train_steps`. If you used W&B, look back at results earlier in training to determine where to stop.
|
| 235 |
+
# On the other hand, if the results don't look like your subject, you might need to increase `max_train_steps`.
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@dataclass
|
| 239 |
+
class TrainConfig(SharedConfig):
|
| 240 |
+
"""Configuration for the finetuning step."""
|
| 241 |
+
|
| 242 |
+
# training prompt looks like `{PREFIX} {INSTANCE_NAME} the {CLASS_NAME} {POSTFIX}`
|
| 243 |
+
prefix: str = "a photo of"
|
| 244 |
+
postfix: str = ""
|
| 245 |
+
|
| 246 |
+
# locator for plaintext file with urls for images of target instance
|
| 247 |
+
instance_example_urls_file: str = str(
|
| 248 |
+
Path(__file__).parent / "instance_example_urls.txt"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# Hyperparameters/constants from the huggingface training example
|
| 252 |
+
resolution: int = 512
|
| 253 |
+
train_batch_size: int = 3
|
| 254 |
+
rank: int = 16 # lora rank
|
| 255 |
+
gradient_accumulation_steps: int = 1
|
| 256 |
+
learning_rate: float = 4e-4
|
| 257 |
+
lr_scheduler: str = "constant"
|
| 258 |
+
lr_warmup_steps: int = 0
|
| 259 |
+
max_train_steps: int = 500
|
| 260 |
+
checkpointing_steps: int = 1000
|
| 261 |
+
seed: int = 117
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@app.function(
|
| 265 |
+
image=image,
|
| 266 |
+
gpu="A100-80GB", # fine-tuning is VRAM-heavy and requires a high-VRAM GPU
|
| 267 |
+
volumes={MODEL_DIR: volume}, # stores fine-tuned model
|
| 268 |
+
timeout=1800, # 30 minutes
|
| 269 |
+
secrets=[huggingface_secret]
|
| 270 |
+
+ (
|
| 271 |
+
[modal.Secret.from_name("wandb-secret", required_keys=["WANDB_API_KEY"])]
|
| 272 |
+
if USE_WANDB
|
| 273 |
+
else []
|
| 274 |
+
),
|
| 275 |
+
)
|
| 276 |
+
def train(instance_example_urls, config):
|
| 277 |
+
import subprocess
|
| 278 |
+
|
| 279 |
+
from accelerate.utils import write_basic_config
|
| 280 |
+
|
| 281 |
+
# load data locally
|
| 282 |
+
img_path = load_images(instance_example_urls)
|
| 283 |
+
|
| 284 |
+
# set up hugging face accelerate library for fast training
|
| 285 |
+
write_basic_config(mixed_precision="bf16")
|
| 286 |
+
|
| 287 |
+
# define the training prompt
|
| 288 |
+
instance_phrase = f"{config.instance_name} the {config.class_name}"
|
| 289 |
+
prompt = f"{config.prefix} {instance_phrase} {config.postfix}".strip()
|
| 290 |
+
|
| 291 |
+
# the model training is packaged as a script, so we have to execute it as a subprocess, which adds some boilerplate
|
| 292 |
+
def _exec_subprocess(cmd: list[str]):
|
| 293 |
+
"""Executes subprocess and prints log to terminal while subprocess is running."""
|
| 294 |
+
process = subprocess.Popen(
|
| 295 |
+
cmd,
|
| 296 |
+
stdout=subprocess.PIPE,
|
| 297 |
+
stderr=subprocess.STDOUT,
|
| 298 |
+
)
|
| 299 |
+
with process.stdout as pipe:
|
| 300 |
+
for line in iter(pipe.readline, b""):
|
| 301 |
+
line_str = line.decode()
|
| 302 |
+
print(f"{line_str}", end="")
|
| 303 |
+
|
| 304 |
+
if exitcode := process.wait() != 0:
|
| 305 |
+
raise subprocess.CalledProcessError(exitcode, "\n".join(cmd))
|
| 306 |
+
|
| 307 |
+
# run training -- see huggingface accelerate docs for details
|
| 308 |
+
print("launching dreambooth training script")
|
| 309 |
+
_exec_subprocess(
|
| 310 |
+
[
|
| 311 |
+
"accelerate",
|
| 312 |
+
"launch",
|
| 313 |
+
"examples/dreambooth/train_dreambooth_lora_flux.py",
|
| 314 |
+
"--mixed_precision=bf16", # half-precision floats most of the time for faster training
|
| 315 |
+
f"--pretrained_model_name_or_path={MODEL_DIR}",
|
| 316 |
+
f"--instance_data_dir={img_path}",
|
| 317 |
+
f"--output_dir={MODEL_DIR}",
|
| 318 |
+
f"--instance_prompt={prompt}",
|
| 319 |
+
f"--resolution={config.resolution}",
|
| 320 |
+
f"--train_batch_size={config.train_batch_size}",
|
| 321 |
+
f"--gradient_accumulation_steps={config.gradient_accumulation_steps}",
|
| 322 |
+
f"--learning_rate={config.learning_rate}",
|
| 323 |
+
f"--lr_scheduler={config.lr_scheduler}",
|
| 324 |
+
f"--lr_warmup_steps={config.lr_warmup_steps}",
|
| 325 |
+
f"--max_train_steps={config.max_train_steps}",
|
| 326 |
+
f"--checkpointing_steps={config.checkpointing_steps}",
|
| 327 |
+
f"--seed={config.seed}", # increased reproducibility by seeding the RNG
|
| 328 |
+
]
|
| 329 |
+
+ (
|
| 330 |
+
[
|
| 331 |
+
"--report_to=wandb",
|
| 332 |
+
# validation output tracking is useful, but currently broken for Flux LoRA training
|
| 333 |
+
# f"--validation_prompt={prompt} in space", # simple test prompt
|
| 334 |
+
# f"--validation_epochs={config.max_train_steps // 5}",
|
| 335 |
+
]
|
| 336 |
+
if USE_WANDB
|
| 337 |
+
else []
|
| 338 |
+
),
|
| 339 |
+
)
|
| 340 |
+
# The trained model information has been output to the volume mounted at `MODEL_DIR`.
|
| 341 |
+
# To persist this data for use in our web app, we 'commit' the changes
|
| 342 |
+
# to the volume.
|
| 343 |
+
volume.commit()
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ## Running our model
|
| 347 |
+
|
| 348 |
+
# To generate images from prompts using our fine-tuned model, we define a Modal Function called `inference`.
|
| 349 |
+
|
| 350 |
+
# Naively, this would seem to be a bad fit for the flexible, serverless infrastructure of Modal:
|
| 351 |
+
# wouldn't you need to include the steps to load the model and spin it up in every function call?
|
| 352 |
+
|
| 353 |
+
# In order to initialize the model just once on container startup,
|
| 354 |
+
# we use Modal's [container lifecycle](https://modal.com/docs/guide/lifecycle-functions) features, which require the function to be part
|
| 355 |
+
# of a class. Note that the `modal.Volume` we saved the model to is mounted here as well,
|
| 356 |
+
# so that the fine-tuned model created by `train` is available to us.
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@app.cls(image=image, gpu="A100", volumes={MODEL_DIR: volume})
|
| 360 |
+
class Model:
|
| 361 |
+
@modal.enter()
|
| 362 |
+
def load_model(self):
|
| 363 |
+
import torch
|
| 364 |
+
from diffusers import DiffusionPipeline
|
| 365 |
+
|
| 366 |
+
# Reload the modal.Volume to ensure the latest state is accessible.
|
| 367 |
+
volume.reload()
|
| 368 |
+
|
| 369 |
+
# set up a hugging face inference pipeline using our model
|
| 370 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 371 |
+
MODEL_DIR,
|
| 372 |
+
torch_dtype=torch.bfloat16,
|
| 373 |
+
).to("cuda")
|
| 374 |
+
pipe.load_lora_weights(MODEL_DIR)
|
| 375 |
+
self.pipe = pipe
|
| 376 |
+
|
| 377 |
+
@modal.method()
|
| 378 |
+
def inference(self, text, config):
|
| 379 |
+
image = self.pipe(
|
| 380 |
+
text,
|
| 381 |
+
num_inference_steps=config.num_inference_steps,
|
| 382 |
+
guidance_scale=config.guidance_scale,
|
| 383 |
+
).images[0]
|
| 384 |
+
|
| 385 |
+
return image
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# ## Wrap the trained model in a Gradio web UI
|
| 389 |
+
|
| 390 |
+
# [Gradio](https://gradio.app) makes it super easy to expose a model's functionality
|
| 391 |
+
# in an easy-to-use, responsive web interface.
|
| 392 |
+
|
| 393 |
+
# This model is a text-to-image generator,
|
| 394 |
+
# so we set up an interface that includes a user-entry text box
|
| 395 |
+
# and a frame for displaying images.
|
| 396 |
+
|
| 397 |
+
# We also provide some example text inputs to help
|
| 398 |
+
# guide users and to kick-start their creative juices.
|
| 399 |
+
|
| 400 |
+
# And we couldn't resist adding some Modal style to it as well!
|
| 401 |
+
|
| 402 |
+
# You can deploy the app on Modal with the command
|
| 403 |
+
# `modal deploy dreambooth_app.py`.
|
| 404 |
+
# You'll be able to come back days, weeks, or months later and find it still ready to go,
|
| 405 |
+
# even though you don't have to pay for a server to run while you're not using it.
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
@dataclass
|
| 409 |
+
class AppConfig(SharedConfig):
|
| 410 |
+
"""Configuration information for inference."""
|
| 411 |
+
|
| 412 |
+
num_inference_steps: int = 50
|
| 413 |
+
guidance_scale: float = 6
|
| 414 |
+
|
| 415 |
+
web_image = image.add_local_dir(
|
| 416 |
+
# Add local web assets to the image
|
| 417 |
+
Path(__file__).parent / "assets",
|
| 418 |
+
remote_path="/assets",
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
@app.function(
|
| 423 |
+
image=web_image,
|
| 424 |
+
max_containers=1,
|
| 425 |
+
)
|
| 426 |
+
@modal.concurrent(max_inputs=1000)
|
| 427 |
+
@modal.asgi_app()
|
| 428 |
+
def fastapi_app():
|
| 429 |
+
import gradio as gr
|
| 430 |
+
from fastapi import FastAPI
|
| 431 |
+
from fastapi.responses import FileResponse
|
| 432 |
+
from gradio.routes import mount_gradio_app
|
| 433 |
+
|
| 434 |
+
web_app = FastAPI()
|
| 435 |
+
|
| 436 |
+
# Call out to the inference in a separate Modal environment with a GPU
|
| 437 |
+
def go(text=""):
|
| 438 |
+
if not text:
|
| 439 |
+
text = example_prompts[0]
|
| 440 |
+
return Model().inference.remote(text, config)
|
| 441 |
+
|
| 442 |
+
# set up AppConfig
|
| 443 |
+
config = AppConfig()
|
| 444 |
+
|
| 445 |
+
instance_phrase = f"{config.instance_name} the {config.class_name}"
|
| 446 |
+
|
| 447 |
+
example_prompts = [
|
| 448 |
+
f"{instance_phrase}",
|
| 449 |
+
f"a painting of {instance_phrase.title()} With A Pearl Earring, by Vermeer",
|
| 450 |
+
f"oil painting of {instance_phrase} flying through space as an astronaut",
|
| 451 |
+
f"a painting of {instance_phrase} in cyberpunk city. character design by cory loftis. volumetric light, detailed, rendered in octane",
|
| 452 |
+
f"drawing of {instance_phrase} high quality, cartoon, path traced, by studio ghibli and don bluth",
|
| 453 |
+
]
|
| 454 |
+
|
| 455 |
+
modal_docs_url = "https://modal.com/docs"
|
| 456 |
+
modal_example_url = f"{modal_docs_url}/examples/dreambooth_app"
|
| 457 |
+
|
| 458 |
+
description = f"""Describe what they are doing or how a particular artist or style would depict them. Be fantastical! Try the examples below for inspiration.
|
| 459 |
+
|
| 460 |
+
### Learn how to make a "Dreambooth" for your own pet [here]({modal_example_url}).
|
| 461 |
+
"""
|
| 462 |
+
|
| 463 |
+
# custom styles: an icon, a background, and a theme
|
| 464 |
+
@web_app.get("/favicon.ico", include_in_schema=False)
|
| 465 |
+
async def favicon():
|
| 466 |
+
return FileResponse("/assets/favicon.svg")
|
| 467 |
+
|
| 468 |
+
@web_app.get("/assets/background.svg", include_in_schema=False)
|
| 469 |
+
async def background():
|
| 470 |
+
return FileResponse("/assets/background.svg")
|
| 471 |
+
|
| 472 |
+
with open("/assets/index.css") as f:
|
| 473 |
+
css = f.read()
|
| 474 |
+
|
| 475 |
+
theme = gr.themes.Default(
|
| 476 |
+
primary_hue="green", secondary_hue="emerald", neutral_hue="neutral"
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# add a gradio UI around inference
|
| 480 |
+
with gr.Blocks(
|
| 481 |
+
theme=theme,
|
| 482 |
+
css=css,
|
| 483 |
+
title=f"Generate images of {config.instance_name} on Modal",
|
| 484 |
+
) as interface:
|
| 485 |
+
gr.Markdown(
|
| 486 |
+
f"# Generate images of {instance_phrase}.\n\n{description}",
|
| 487 |
+
)
|
| 488 |
+
with gr.Row():
|
| 489 |
+
inp = gr.Textbox( # input text component
|
| 490 |
+
label="",
|
| 491 |
+
placeholder=f"Describe the version of {instance_phrase} you'd like to see",
|
| 492 |
+
lines=10,
|
| 493 |
+
)
|
| 494 |
+
out = gr.Image( # output image component
|
| 495 |
+
height=512, width=512, label="", min_width=512, elem_id="output"
|
| 496 |
+
)
|
| 497 |
+
with gr.Row():
|
| 498 |
+
btn = gr.Button("Dream", variant="primary", scale=2)
|
| 499 |
+
btn.click(
|
| 500 |
+
fn=go, inputs=inp, outputs=out
|
| 501 |
+
) # connect inputs and outputs with inference function
|
| 502 |
+
|
| 503 |
+
gr.Button( # shameless plug
|
| 504 |
+
"⚡️ Powered by Modal",
|
| 505 |
+
variant="secondary",
|
| 506 |
+
link="https://modal.com",
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
with gr.Column(variant="compact"):
|
| 510 |
+
# add in a few examples to inspire users
|
| 511 |
+
for ii, prompt in enumerate(example_prompts):
|
| 512 |
+
btn = gr.Button(prompt, variant="secondary")
|
| 513 |
+
btn.click(fn=lambda idx=ii: example_prompts[idx], outputs=inp)
|
| 514 |
+
|
| 515 |
+
# mount for execution on Modal
|
| 516 |
+
return mount_gradio_app(
|
| 517 |
+
app=web_app,
|
| 518 |
+
blocks=interface,
|
| 519 |
+
path="/",
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
# ## Running your fine-tuned model from the command line
|
| 524 |
+
|
| 525 |
+
# You can use the `modal` command-line interface to set up, customize, and deploy this app:
|
| 526 |
+
|
| 527 |
+
# - `modal run diffusers_lora_finetune.py` will train the model. Change the `instance_example_urls_file` to point to your own pet's images.
|
| 528 |
+
# - `modal serve diffusers_lora_finetune.py` will [serve](https://modal.com/docs/guide/webhooks#developing-with-modal-serve) the Gradio interface at a temporary location. Great for iterating on code!
|
| 529 |
+
# - `modal shell diffusers_lora_finetune.py` is a convenient helper to open a bash [shell](https://modal.com/docs/guide/developing-debugging#interactive-shell) in our image. Great for debugging environment issues.
|
| 530 |
+
|
| 531 |
+
# Remember, once you've trained your own fine-tuned model, you can deploy it permanently -- for no cost when it is not being used! --
|
| 532 |
+
# using `modal deploy diffusers_lora_finetune.py`.
|
| 533 |
+
|
| 534 |
+
# If you just want to try the app out, you can find our deployment [here](https://modal-labs--example-lora-flux-fastapi-app.modal.run).
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
@app.local_entrypoint()
def run(  # add more config params here to make training configurable
    max_train_steps: int = 250,
):
    """Local entrypoint: download base weights, then launch LoRA fine-tuning on Modal.

    Args:
        max_train_steps: number of optimizer steps for the fine-tuning run.
    """
    print("🎨 loading model")
    download_models.remote(SharedConfig())
    print("🎨 setting up training")
    config = TrainConfig(max_train_steps=max_train_steps)
    # Read the URL list from the *instance* config (was `TrainConfig.<attr>`,
    # which silently ignored any instance-level override of the file path).
    instance_example_urls = (
        Path(config.instance_example_urls_file).read_text().splitlines()
    )
    train.remote(instance_example_urls, config)
    print("🎨 training finished")
|
instance_example_urls.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
https://modal-public-assets.s3.amazonaws.com/example-dreambooth-app/fkRYgv6.png
|
| 2 |
+
https://modal-public-assets.s3.amazonaws.com/example-dreambooth-app/98k9yDg.jpg
|
| 3 |
+
https://modal-public-assets.s3.amazonaws.com/example-dreambooth-app/gHlW8Kw.jpg
|