Spaces:
Running on Zero
Running on Zero
File size: 3,250 Bytes
a8a9bce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 | import logging
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
import torch
import numpy as np
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from flowdis.autoencoder import AutoEncoder
from flowdis.conditioner import HFEmbedder
from flowdis.configs import configs
from flowdis.loaders import load_autoencoder, load_clip, load_t5, load_transformer
from flowdis.model import Flux
logger = logging.getLogger(__name__)
@dataclass
class Models:
    """Container for the four model components built by `load_models`."""

    # CLIP text embedder (the value produced by `load_clip`).
    clip: HFEmbedder
    # T5 text embedder (the value produced by `load_t5`).
    t5: HFEmbedder
    # Autoencoder (the value produced by `load_autoencoder`).
    ae: AutoEncoder
    # Flux transformer (the value produced by `load_transformer`).
    transformer: Flux
def load_models(
    root_model_dir: str | Path | None = None,
    device: str | torch.device = "cuda",
) -> Models:
    """
    Load the models for the FlowDIS pipeline.

    The following files are read relative to ``root_model_dir``:

    * ``t5-v1_1-xxl/model.safetensors``
    * ``clip-vit-large-patch14/model.safetensors``
    * ``ae.safetensors``
    * ``flowdis-transformer.safetensors``

    Args:
        root_model_dir: The root model directory, as a ``Path`` or a plain
            string. If None, the models are downloaded from the Hugging Face
            Hub repo ``PAIR/FlowDIS``.
        device: The device to load the models on.

    Returns:
        Models: The loaded CLIP, T5, autoencoder and transformer.
    """
    if root_model_dir is None:
        root_model_dir = download_from_hf_hub("PAIR/FlowDIS")
    elif not isinstance(root_model_dir, Path):
        # A plain string does not support the `/` joins below; normalise it.
        # Existing Path instances (including subclasses) are left untouched.
        root_model_dir = Path(root_model_dir)

    logger.info("Loading T5.")
    t5 = load_t5(
        model_path=root_model_dir / "t5-v1_1-xxl" / "model.safetensors",
        device=device,
        max_length=512,
    )

    logger.info("Loading CLIP.")
    clip = load_clip(
        model_path=root_model_dir / "clip-vit-large-patch14" / "model.safetensors",
        device=device,
    )

    logger.info("Loading AE.")
    ae = load_autoencoder(
        model_path=root_model_dir / "ae.safetensors",
        device=device,
    )

    logger.info("Loading Transformer.")
    transformer = load_transformer(
        model_name="flowdis",
        model_path=root_model_dir / "flowdis-transformer.safetensors",
        device=device,
    )

    logger.info("All models loaded.")
    return Models(clip=clip, t5=t5, ae=ae, transformer=transformer)
def download_from_hf_hub(
    repo_id: str,
    cache_dir: str | Path | None = None,
    revision: str | None = None,
) -> Path:
    """
    Fetch a snapshot of a Hugging Face Hub repository and return where it lives.

    Args:
        repo_id: Hub repository id, for example "PAIR/FlowDIS".
        cache_dir: Cache directory forwarded to `snapshot_download`; when
            None, huggingface_hub picks its own default location.
        revision: Git revision (branch, tag, or commit SHA) forwarded to
            `snapshot_download`; None leaves the choice to huggingface_hub.

    Returns:
        The local snapshot directory reported by `snapshot_download`, wrapped
        in a `Path` so it can be handed to `load_models` as `root_model_dir`.
    """
    logger.info(f"Downloading {repo_id} from Hugging Face Hub.")

    # Collect the forwarded arguments once so the call itself stays short.
    request = {
        "repo_id": repo_id,
        "cache_dir": cache_dir,
        "revision": revision,
    }
    snapshot_dir = snapshot_download(**request)

    logger.info(f"Snapshot available at {snapshot_dir}.")
    return Path(snapshot_dir)
def green_screen(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Composite an image over a pure-green background using a mask.

    Each output pixel is ``img * a + (1 - a) * (0, 255, 0)`` with
    ``a = mask / 255``: a mask value of 255 keeps the image pixel, 0 yields
    green, and values in between blend the two.

    Args:
        img: Image array that broadcasts against three colour channels,
            e.g. shape ``(H, W, 3)``.
        mask: Mask on a 0-255 scale with shape ``(H, W)`` or ``(H, W, 1)``.

    Returns:
        The blended image as a ``uint8`` array.

    Raises:
        ValueError: If ``mask`` is neither ``(H, W)`` nor ``(H, W, 1)``.
    """
    foreground = np.asarray(img)
    alpha = np.asarray(mask) / 255

    if alpha.ndim == 2:
        # Add a trailing channel axis; broadcasting spreads it over the
        # colour channels, so no explicit repeat is needed.
        alpha = alpha[:, :, np.newaxis]
    elif alpha.ndim != 3 or alpha.shape[2] != 1:
        raise ValueError(
            f"mask must have shape (H, W) or (H, W, 1), got {alpha.shape}"
        )

    background = np.array([0, 255, 0], dtype=np.float64)
    blended = foreground * alpha + (1 - alpha) * background

    # Round to nearest before the integer cast: a bare astype truncates, which
    # biases soft-mask blends downwards and lets float error (254.999...) drop
    # a channel by one. Clipping keeps the cast inside the uint8 range.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
|