Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -82,7 +82,7 @@ class ConceptAlignerModel:
|
|
| 82 |
self.text_encoder.load_state_dict(adapter_state, strict=True)
|
| 83 |
print(" β T5 encoder loaded")
|
| 84 |
|
| 85 |
-
# Only download VAE (small ~
|
| 86 |
print(" Loading VAE from FLUX.1-dev...")
|
| 87 |
vae = AutoencoderKL.from_pretrained(
|
| 88 |
'black-forest-labs/FLUX.1-dev',
|
|
@@ -90,25 +90,23 @@ class ConceptAlignerModel:
|
|
| 90 |
torch_dtype=self.dtype,
|
| 91 |
token=HF_TOKEN
|
| 92 |
).to(self.device)
|
| 93 |
-
print(" β VAE loaded
|
| 94 |
|
| 95 |
-
# Create transformer
|
| 96 |
-
print("
|
| 97 |
-
|
| 98 |
-
# Get config only (no weights download)
|
| 99 |
-
from diffusers.models.transformers.transformer_flux import FluxTransformerConfig
|
| 100 |
config = FluxTransformer2DModel.load_config(
|
| 101 |
'black-forest-labs/FLUX.1-dev',
|
| 102 |
subfolder="transformer",
|
| 103 |
token=HF_TOKEN
|
| 104 |
)
|
| 105 |
|
| 106 |
-
# Initialize
|
| 107 |
-
transformer
|
| 108 |
-
|
|
|
|
| 109 |
|
| 110 |
-
# Add LoRA config
|
| 111 |
-
print(" Adding LoRA adapter
|
| 112 |
transformer_lora_config = LoraConfig(
|
| 113 |
r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
|
| 114 |
target_modules=[
|
|
@@ -120,15 +118,25 @@ class ConceptAlignerModel:
|
|
| 120 |
)
|
| 121 |
transformer.add_adapter(transformer_lora_config)
|
| 122 |
transformer.context_embedder.requires_grad_(True)
|
|
|
|
| 123 |
|
| 124 |
-
# Load YOUR FULL fine-tuned transformer weights
|
| 125 |
-
print(" Loading
|
| 126 |
transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
transformer = transformer.to(self.device).to(self.dtype)
|
| 129 |
-
print(" β Fine-tuned transformer loaded
|
| 130 |
|
| 131 |
# Load empty pooled clip
|
|
|
|
| 132 |
self.empty_pooled_clip = torch.load(
|
| 133 |
os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
|
| 134 |
map_location=self.device,
|
|
@@ -136,8 +144,8 @@ class ConceptAlignerModel:
|
|
| 136 |
).to(self.dtype)
|
| 137 |
print(" β Empty pooled clip loaded")
|
| 138 |
|
| 139 |
-
# Create scheduler (just config
|
| 140 |
-
print(" Loading scheduler
|
| 141 |
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 142 |
'black-forest-labs/FLUX.1-dev',
|
| 143 |
subfolder="scheduler",
|
|
@@ -146,7 +154,7 @@ class ConceptAlignerModel:
|
|
| 146 |
print(" β Scheduler loaded")
|
| 147 |
|
| 148 |
# Create pipeline
|
| 149 |
-
print("
|
| 150 |
self.pipe = CustomFluxKontextPipeline(
|
| 151 |
scheduler=noise_scheduler,
|
| 152 |
aligner=self.model,
|
|
@@ -156,10 +164,8 @@ class ConceptAlignerModel:
|
|
| 156 |
).to(self.device)
|
| 157 |
|
| 158 |
print("="*60)
|
| 159 |
-
print("
|
| 160 |
print("="*60)
|
| 161 |
-
print(f"Total downloads: ~330MB VAE + ~26GB your checkpoints")
|
| 162 |
-
print(f"Saved: ~24GB by not downloading base FLUX transformer!")
|
| 163 |
|
| 164 |
# Print memory usage
|
| 165 |
if torch.cuda.is_available():
|
|
@@ -211,7 +217,6 @@ with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
|
|
| 211 |
# π¨ ConceptAligner Demo
|
| 212 |
|
| 213 |
Generate images with fine-tuned concept alignment using FLUX!
|
| 214 |
-
This demo uses fully fine-tuned weights - no base model downloads needed.
|
| 215 |
""")
|
| 216 |
|
| 217 |
with gr.Row():
|
|
|
|
| 82 |
self.text_encoder.load_state_dict(adapter_state, strict=True)
|
| 83 |
print(" β T5 encoder loaded")
|
| 84 |
|
| 85 |
+
# Only download VAE (small ~168MB)
|
| 86 |
print(" Loading VAE from FLUX.1-dev...")
|
| 87 |
vae = AutoencoderKL.from_pretrained(
|
| 88 |
'black-forest-labs/FLUX.1-dev',
|
|
|
|
| 90 |
torch_dtype=self.dtype,
|
| 91 |
token=HF_TOKEN
|
| 92 |
).to(self.device)
|
| 93 |
+
print(" β VAE loaded")
|
| 94 |
|
| 95 |
+
# Create transformer from config only (download config.json but not weights)
|
| 96 |
+
print(" Downloading transformer config only...")
|
|
|
|
|
|
|
|
|
|
| 97 |
config = FluxTransformer2DModel.load_config(
|
| 98 |
'black-forest-labs/FLUX.1-dev',
|
| 99 |
subfolder="transformer",
|
| 100 |
token=HF_TOKEN
|
| 101 |
)
|
| 102 |
|
| 103 |
+
# Initialize transformer from config (no weights)
|
| 104 |
+
print(" Initializing transformer architecture from config...")
|
| 105 |
+
transformer = FluxTransformer2DModel.from_config(config, torch_dtype=self.dtype)
|
| 106 |
+
print(" β Empty transformer initialized")
|
| 107 |
|
| 108 |
+
# Add LoRA adapter config
|
| 109 |
+
print(" Adding LoRA adapter layers...")
|
| 110 |
transformer_lora_config = LoraConfig(
|
| 111 |
r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
|
| 112 |
target_modules=[
|
|
|
|
| 118 |
)
|
| 119 |
transformer.add_adapter(transformer_lora_config)
|
| 120 |
transformer.context_embedder.requires_grad_(True)
|
| 121 |
+
print(" β LoRA adapters added")
|
| 122 |
|
| 123 |
+
# Load YOUR FULL fine-tuned transformer weights
|
| 124 |
+
print(" Loading your fine-tuned transformer weights...")
|
| 125 |
transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
|
| 126 |
+
|
| 127 |
+
# Load with strict=False in case of minor key mismatches
|
| 128 |
+
missing_keys, unexpected_keys = transformer.load_state_dict(transformer_state, strict=False)
|
| 129 |
+
|
| 130 |
+
if missing_keys:
|
| 131 |
+
print(f" β οΈ Missing keys: {len(missing_keys)}")
|
| 132 |
+
if unexpected_keys:
|
| 133 |
+
print(f" β οΈ Unexpected keys: {len(unexpected_keys)}")
|
| 134 |
+
|
| 135 |
transformer = transformer.to(self.device).to(self.dtype)
|
| 136 |
+
print(" β Fine-tuned transformer loaded")
|
| 137 |
|
| 138 |
# Load empty pooled clip
|
| 139 |
+
print(" Loading empty pooled clip...")
|
| 140 |
self.empty_pooled_clip = torch.load(
|
| 141 |
os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
|
| 142 |
map_location=self.device,
|
|
|
|
| 144 |
).to(self.dtype)
|
| 145 |
print(" β Empty pooled clip loaded")
|
| 146 |
|
| 147 |
+
# Create scheduler (just config)
|
| 148 |
+
print(" Loading scheduler...")
|
| 149 |
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 150 |
'black-forest-labs/FLUX.1-dev',
|
| 151 |
subfolder="scheduler",
|
|
|
|
| 154 |
print(" β Scheduler loaded")
|
| 155 |
|
| 156 |
# Create pipeline
|
| 157 |
+
print(" Assembling pipeline...")
|
| 158 |
self.pipe = CustomFluxKontextPipeline(
|
| 159 |
scheduler=noise_scheduler,
|
| 160 |
aligner=self.model,
|
|
|
|
| 164 |
).to(self.device)
|
| 165 |
|
| 166 |
print("="*60)
|
| 167 |
+
print("β
ALL MODELS LOADED SUCCESSFULLY!")
|
| 168 |
print("="*60)
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Print memory usage
|
| 171 |
if torch.cuda.is_available():
|
|
|
|
| 217 |
# π¨ ConceptAligner Demo
|
| 218 |
|
| 219 |
Generate images with fine-tuned concept alignment using FLUX!
|
|
|
|
| 220 |
""")
|
| 221 |
|
| 222 |
with gr.Row():
|