Update app.py
Browse files
app.py
CHANGED
|
@@ -1,12 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import random
|
| 4 |
import spaces
|
| 5 |
import torch
|
| 6 |
-
from
|
|
|
|
|
|
|
| 7 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
| 8 |
|
| 9 |
-
from model import Flux
|
| 10 |
|
| 11 |
def calculate_shift(
|
| 12 |
image_seq_len,
|
|
@@ -174,20 +179,17 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
|
|
| 174 |
@dataclass
|
| 175 |
class ModelSpec:
|
| 176 |
params: FluxParams
|
| 177 |
-
ae_params: AutoEncoderParams
|
| 178 |
-
ckpt_path: str
|
| 179 |
-
ae_path: str
|
| 180 |
repo_id: str
|
| 181 |
repo_flow: str
|
| 182 |
repo_ae: str
|
| 183 |
repo_id_ae: str
|
| 184 |
|
|
|
|
| 185 |
config = ModelSpec(
|
| 186 |
repo_id="TencentARC/flux-mini",
|
| 187 |
repo_flow="flux-mini.safetensors",
|
| 188 |
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 189 |
repo_ae="ae.safetensors",
|
| 190 |
-
ckpt_path=os.getenv("FLUX_MINI", None),
|
| 191 |
params=FluxParams(
|
| 192 |
in_channels=64,
|
| 193 |
vec_in_dim=768,
|
|
@@ -202,35 +204,33 @@ config = ModelSpec(
|
|
| 202 |
qkv_bias=True,
|
| 203 |
guidance_embed=True,
|
| 204 |
)
|
|
|
|
| 205 |
|
| 206 |
|
| 207 |
-
def load_flow_model2(device: str = "cuda", hf_download: bool = True):
|
| 208 |
-
if (
|
| 209 |
-
and config.repo_id is not None
|
| 210 |
and config.repo_flow is not None
|
| 211 |
and hf_download
|
| 212 |
):
|
| 213 |
-
ckpt_path = hf_hub_download(
|
| 214 |
|
| 215 |
-
model = Flux(params)
|
| 216 |
if ckpt_path is not None:
|
| 217 |
sd = load_sft(ckpt_path, device=str(device))
|
| 218 |
missing, unexpected = model.load_state_dict(sd, strict=True)
|
| 219 |
return model
|
| 220 |
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
| 224 |
dtype = torch.bfloat16
|
| 225 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 226 |
|
| 227 |
-
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
|
| 228 |
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
|
| 229 |
text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder").to(device)
|
| 230 |
-
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
|
| 231 |
text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2").to(device)
|
| 232 |
-
tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
|
| 233 |
-
transformer = load_flow_model2(device)
|
| 234 |
|
| 235 |
pipe = FluxPipeline(
|
| 236 |
scheduler,
|
|
@@ -238,7 +238,7 @@ pipe = FluxPipeline(
|
|
| 238 |
text_encoder,
|
| 239 |
tokenizer,
|
| 240 |
text_encoder_2,
|
| 241 |
-
tokenizer_2
|
| 242 |
transformer
|
| 243 |
)
|
| 244 |
torch.cuda.empty_cache()
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Union, Optional, List, Any, Dict
|
| 3 |
+
|
| 4 |
import gradio as gr
|
| 5 |
import numpy as np
|
| 6 |
import random
|
| 7 |
import spaces
|
| 8 |
import torch
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
|
| 11 |
+
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
|
| 12 |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
| 13 |
|
| 14 |
+
from model import Flux, FluxParams
|
| 15 |
|
| 16 |
def calculate_shift(
|
| 17 |
image_seq_len,
|
|
|
|
| 179 |
@dataclass
class ModelSpec:
    """Bundle of the settings needed to build and load one Flux checkpoint.

    Field order is part of the positional constructor signature and must not
    be rearranged.
    """

    # Hyper-parameters passed to the ``Flux`` constructor.
    params: FluxParams
    # Hugging Face Hub repository that hosts the flow-model weights.
    repo_id: str
    # File name of the flow-model weights inside ``repo_id``.
    repo_flow: str
    # Presumably the autoencoder weights file name; not read by the loader
    # visible in this file section -- TODO confirm.
    repo_ae: str
    # Presumably the Hub repository hosting ``repo_ae`` -- TODO confirm.
    repo_id_ae: str
|
| 186 |
|
| 187 |
+
|
| 188 |
config = ModelSpec(
|
| 189 |
repo_id="TencentARC/flux-mini",
|
| 190 |
repo_flow="flux-mini.safetensors",
|
| 191 |
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
| 192 |
repo_ae="ae.safetensors",
|
|
|
|
| 193 |
params=FluxParams(
|
| 194 |
in_channels=64,
|
| 195 |
vec_in_dim=768,
|
|
|
|
| 204 |
qkv_bias=True,
|
| 205 |
guidance_embed=True,
|
| 206 |
)
|
| 207 |
+
)
|
| 208 |
|
| 209 |
|
| 210 |
+
def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
    """Build a ``Flux`` model from ``config.params`` and load its checkpoint.

    Args:
        config: Spec object exposing ``params``, ``repo_id`` and ``repo_flow``.
        device: Device identifier; it is stringified and forwarded to the
            checkpoint loader.
        hf_download: When true (and both repo fields are set) the checkpoint
            is fetched from the Hugging Face Hub.

    Returns:
        The constructed ``Flux`` instance. If no checkpoint was resolved the
        model is returned with whatever weights its constructor gave it.
    """
    # Bind up front: previously this name only existed inside the branch
    # below, so the no-download path crashed with UnboundLocalError.
    ckpt_path = None
    if (
        config.repo_id is not None
        and config.repo_flow is not None
        and hf_download
    ):
        # Legacy "sft" names are mapped onto the "safetensors" file name.
        ckpt_path = hf_hub_download(
            config.repo_id, config.repo_flow.replace("sft", "safetensors")
        )

    model = Flux(config.params)
    if ckpt_path is not None:
        # NOTE(review): ``load_sft`` is not among the imports visible in this
        # section of the file -- confirm it is in scope before deploying.
        sd = load_sft(ckpt_path, device=str(device))
        # strict=True: a key mismatch is expected to raise instead of being
        # returned, so the result of the call is not inspected.
        model.load_state_dict(sd, strict=True)
    return model
|
| 222 |
|
| 223 |
|
|
|
|
|
|
|
| 224 |
# Precision handed to the VAE loader below.
dtype = torch.bfloat16

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Every stock component is pulled from a sub-folder of the FLUX.1-dev repo.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="scheduler",
)
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=dtype,
).to(device)

# First text encoder and its tokenizer.
text_encoder = CLIPTextModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder",
).to(device)
tokenizer = CLIPTokenizer.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="tokenizer",
)

# Second text encoder and its tokenizer.
text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
).to(device)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="tokenizer_2",
)

# The transformer comes from the module-level ``config`` rather than from the
# FLUX.1-dev repository.
transformer = load_flow_model2(config, device)
|
| 234 |
|
| 235 |
pipe = FluxPipeline(
|
| 236 |
scheduler,
|
|
|
|
| 238 |
text_encoder,
|
| 239 |
tokenizer,
|
| 240 |
text_encoder_2,
|
| 241 |
+
tokenizer_2,
|
| 242 |
transformer
|
| 243 |
)
|
| 244 |
torch.cuda.empty_cache()
|