import torch
import os
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig
import gradio as gr
# Configuration
MODEL_REPO = "Shaoan/ConceptAligner-Weights" # Your model repo
CHECKPOINT_DIR = "./checkpoint"
def download_checkpoint():
    """Download checkpoint files from HF model repo"""
    print("Downloading checkpoint files...")
    files = [
        "model.safetensors",
        "model_1.safetensors",
        "model_2.safetensors"
    ]
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for filename in files:
        local_path = os.path.join(CHECKPOINT_DIR, filename)
        if not os.path.exists(local_path):
            print(f" Downloading {filename}...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
                local_dir_use_symlinks=False
            )
            print(f" ✓ {filename} downloaded")
    print("✓ All checkpoint files ready!")
class ConceptAlignerModel:
    def __init__(self):
        # Download checkpoint first
        download_checkpoint()
        self.checkpoint_path = CHECKPOINT_DIR
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.previous_image = None
        self.previous_prompt = None
        print(f"\n{'='*60}")
        print(f"Loading ConceptAligner Model")
        print(f"Device: {self.device}")
        print(f"{'='*60}")
        self.setup_models()
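    # bfloat16 matches the precision FLUX.1-dev is distributed in and halves
    # memory versus float32; the float32 branch is a CPU-only fallback and
    # will be extremely slow for a 12B-parameter transformer.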
    def setup_models(self):
        """Load all models"""
        # Load ConceptAligner
        print(f" Loading ConceptAligner...")
        self.model = ConceptAligner().to(self.device).to(self.dtype)
        adapter_path = os.path.join(self.checkpoint_path, "model_1.safetensors")
        adapter_state = load_file(adapter_path)
        self.model.load_state_dict(adapter_state, strict=True)
        print(f" ✓ Adapter loaded")
        # Load T5 encoder
        print(f" Loading T5 encoder...")
        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
        adapter_path = os.path.join(self.checkpoint_path, "model_2.safetensors")
        adapter_state = load_file(adapter_path)
        if "t5_encoder.shared.weight" in adapter_state and "t5_encoder.encoder.embed_tokens.weight" not in adapter_state:
            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
        self.text_encoder.load_state_dict(adapter_state, strict=True)
        print(f" ✓ T5 Adapter loaded")
        # Load VAE
        print(f" Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="vae",
            torch_dtype=self.dtype
        ).to(self.device)
        # Load transformer
        print(f" Loading transformer...")
        transformer = FluxTransformer2DModel.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="transformer",
            torch_dtype=self.dtype
        )
        target_modules = [
            "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
            "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
            "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
            "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
        ]
        transformer_lora_config = LoraConfig(
            r=256,
            lora_alpha=256,
            lora_dropout=0.0,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        transformer.add_adapter(transformer_lora_config)
        transformer.context_embedder.requires_grad_(True)
        # Load fine-tuned transformer
        transformer_path = os.path.join(self.checkpoint_path, "model.safetensors")
        transformer_state = load_file(transformer_path)
        transformer.load_state_dict(transformer_state, strict=True)
        print(f" ✓ Fine-tuned transformer loaded")
        transformer = transformer.to(self.device)
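        # The LoRA config above injects rank-256 adapters into every attention
        # and feed-forward projection of the FLUX transformer; with
        # lora_alpha == r the effective scale alpha/r is 1.0, so adapter
        # deltas are applied at full strength. init_lora_weights is moot here
        # because the trained weights are loaded immediately afterwards, and
        # requires_grad_(True) on context_embedder appears to be a
        # training-time leftover with no effect under no_grad inference.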
        # Load or download empty pooled clip
        empty_clip_path = "empty_pooled_clip.pt"
        if not os.path.exists(empty_clip_path):
            print(" Downloading empty_pooled_clip.pt...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename="empty_pooled_clip.pt",
                local_dir=".",
                local_dir_use_symlinks=False
            )
        self.empty_pooled_clip = torch.load(empty_clip_path, map_location=self.device).to(self.dtype)
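        # FLUX conditions on a pooled CLIP-L embedding alongside the T5 token
        # sequence; a precomputed embedding of the empty prompt (presumably
        # what empty_pooled_clip.pt holds) lets this app skip loading the
        # CLIP text encoder entirely.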
        # Create pipeline
        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
        )
        self.pipe = CustomFluxKontextPipeline(
            scheduler=noise_scheduler,
            aligner=self.model.to(self.device).to(self.dtype),
            transformer=transformer.to(self.device).to(self.dtype),
            vae=vae.to(self.device).to(self.dtype),
            text_embedder=self.text_encoder.to(self.device).to(self.dtype),
        ).to(self.device)
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            reserved = torch.cuda.memory_reserved(0) / 1024**3
            print(f" ✓ Pipeline ready on {self.device}")
            print(f" GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
        else:
            print(f" ✓ Pipeline ready on {self.device}")
    @torch.no_grad()
    def generate_image(
        self,
        prompt,
        threshold=0.0,
        topk=0,
        height=512,
        width=512,
        guidance_scale=3.5,
        true_cf_scale=1.0,
        num_inference_steps=20,
        seed=1995
    ):
        """Generate image and return previous + current for comparison"""
        if not prompt.strip():
            return self.previous_image, None, self.previous_prompt or ""
        try:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            current_image = self.pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                true_cfg_scale=true_cf_scale,
                max_sequence_length=512,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                generator=generator,
            ).images[0]
            prev_image = self.previous_image
            prev_prompt = self.previous_prompt or "No previous generation"
            self.previous_image = current_image
            self.previous_prompt = prompt
            return prev_image, current_image, prev_prompt
        except Exception as e:
            import traceback
            error_msg = f"✗ Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return self.previous_image, None, self.previous_prompt or ""
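        # Return contract: (previous_image, current_image, previous_prompt),
        # presumably wired to two gr.Image outputs and a caption in the Gradio
        # UI. Note that threshold and topk are accepted but not forwarded to
        # the pipeline call above.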
    def reset_history(self):
        """Clear generation history"""
        self.previous_image = None
        self.previous_prompt = None
        return None, None, "No previous generation"
# Initialize model
print("Initializing ConceptAligner model...")
model = ConceptAlignerModel()
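# Minimal smoke test (hypothetical; the Space itself exposes these methods
# through a Gradio Blocks UI):
#
#   prev, cur, prev_prompt = model.generate_image("a red vintage bicycle", seed=42)
#   if cur is not None:
#       cur.save("out.png")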