# ConceptAligner / app.py
# Author: Shaoan
# Commit: "Upload folder using huggingface_hub" (5c693f5, verified)
"""
ConceptAligner - Same GPU behavior as FLUX demo
Models loaded at startup, GPU allocated only for inference
"""
# CRITICAL: Import spaces FIRST
import spaces
import torch
import gradio as gr
import os
from huggingface_hub import hf_hub_download, login
from safetensors.torch import load_file
from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig
# Login: authenticate with the Hugging Face Hub only when an HF_TOKEN secret
# is set. The token is also passed explicitly to every download below
# (presumably needed for the FLUX.1-dev base repository — confirm).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✓ Logged in to Hugging Face")
# --- Configuration ---------------------------------------------------------
# Hub repository holding the fine-tuned checkpoint files.
MODEL_REPO = "Shaoan/ConceptAligner-Weights"
# Local directory the checkpoint files are downloaded into and loaded from.
CHECKPOINT_DIR = "./checkpoint"
# Every model component is cast to bfloat16.
dtype = torch.bfloat16
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One example per row; each row holds the value for the single prompt input.
EXAMPLE_PROMPTS = [
    [
        """In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""
    ],
]
def download_checkpoint():
    """Make sure every checkpoint file is present under ``CHECKPOINT_DIR``.

    Files that already exist locally are skipped; missing ones are fetched
    from ``MODEL_REPO`` with ``hf_hub_download`` into ``CHECKPOINT_DIR``,
    passing ``HF_TOKEN`` for authentication.
    """
    print("Downloading checkpoint files...")
    # model.safetensors -> transformer, model_1 -> ConceptAligner,
    # model_2 -> T5 encoder (see the loaders below); empty_pooled_clip.pt is
    # fetched here but not read in this file.
    files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"]
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for filename in files:
        if os.path.exists(os.path.join(CHECKPOINT_DIR, filename)):
            continue  # already downloaded on a previous start
        print(f" Downloading {filename}...")
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=filename,
            local_dir=CHECKPOINT_DIR,
            token=HF_TOKEN,
        )
    print("✓ Checkpoint files ready!")
# Fetch any missing checkpoint files once, before the loaders below read them.
download_checkpoint()
# Load models at startup (like FLUX does) so each request only pays for inference.
print("Loading models...")
# ConceptAligner: build it, move it to the target device/dtype, restore weights.
aligner_model = ConceptAligner().to(device).to(dtype)
aligner_state = load_file(os.path.join(CHECKPOINT_DIR, "model_1.safetensors"))
aligner_model.load_state_dict(aligner_state, strict=True)
# load_state_dict copies the tensors into the model, so the loaded dict is
# dead weight; drop it instead of keeping it alive as a module global.
del aligner_state
print(" ✓ ConceptAligner")
# Load T5 encoder (LoRA-augmented) and restore its fine-tuned weights.
text_encoder = LoraT5Embedder(device=device).to(dtype)
adapter_state = load_file(os.path.join(CHECKPOINT_DIR, "model_2.safetensors"))
# The strict load below needs both embedding keys. When the checkpoint only
# carries `shared.weight`, reuse it for `encoder.embed_tokens.weight`
# (presumably the tied embedding was deduplicated when saving — TODO confirm).
if "t5_encoder.shared.weight" in adapter_state:
    adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
text_encoder.load_state_dict(adapter_state, strict=True)
# The encoder now owns copies of these tensors; release the loaded dict so a
# second full set of T5 weights does not stay resident in CPU memory.
del adapter_state
print(" ✓ T5 Encoder")
# Load VAE: the stock FLUX.1-dev autoencoder in `dtype`, moved to `device`.
vae = AutoencoderKL.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder="vae",
    torch_dtype=dtype,
    token=HF_TOKEN,
).to(device)
print(" ✓ VAE")
# Load transformer: build the FLUX.1-dev architecture from its config only
# (no pretrained weights are downloaded), attach the LoRA layers, then fill
# every parameter from the fine-tuned checkpoint.
config = FluxTransformer2DModel.load_config(
    'black-forest-labs/FLUX.1-dev',
    subfolder="transformer",
    token=HF_TOKEN,
)
# NOTE(review): confirm from_config honours torch_dtype; the explicit
# .to(dtype) further down casts the model regardless.
transformer = FluxTransformer2DModel.from_config(config, torch_dtype=dtype)
transformer_lora_config = LoraConfig(
    r=256,
    lora_alpha=256,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=[
        "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
        "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
        "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
        "proj_mlp", "proj_out", "norm.linear", "norm1.linear",
    ],
)
transformer.add_adapter(transformer_lora_config)
# Carried over from the training setup; it has no effect under no_grad inference.
transformer.context_embedder.requires_grad_(True)
transformer_state = load_file(os.path.join(CHECKPOINT_DIR, "model.safetensors"))
_transformer_load = transformer.load_state_dict(transformer_state, strict=False)
# strict=False accepts key mismatches silently. Because the model was built
# from config only, any key the checkpoint lacks keeps its random init, so
# surface mismatches in the startup log instead of hiding them.
if _transformer_load.missing_keys:
    print(f" ! Transformer checkpoint is missing {len(_transformer_load.missing_keys)} keys, "
          f"e.g. {_transformer_load.missing_keys[:3]}")
if _transformer_load.unexpected_keys:
    print(f" ! Transformer checkpoint has {len(_transformer_load.unexpected_keys)} unused keys, "
          f"e.g. {_transformer_load.unexpected_keys[:3]}")
# The weights were copied into `transformer`; drop the (multi-GB) CPU copy.
del transformer_state, _transformer_load
transformer = transformer.to(device).to(dtype)
print(" ✓ Transformer")
# Load scheduler: FLUX.1-dev's flow-matching Euler scheduler, as published.
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    token=HF_TOKEN,
    subfolder="scheduler",
)
# Create pipeline: wire the loaded components into the custom Kontext pipeline.
pipe = CustomFluxKontextPipeline(
    scheduler=noise_scheduler,
    aligner=aligner_model,
    transformer=transformer,
    vae=vae,
    text_embedder=text_encoder,
).to(device)
print("✅ Models loaded and ready!")
# Hand cached-but-unused allocator blocks from the loading phase back.
torch.cuda.empty_cache()
# History tracking: the last successfully generated image and its prompt,
# shown in the "Previous" panel on the next generation.
# NOTE(review): these are process-wide module globals, so all browser sessions
# presumably share a single history; consider gr.State if per-user history is
# intended.
previous_image = previous_prompt = None
@spaces.GPU(duration=75)
@torch.no_grad()
def generate_image(prompt, height=512, width=512, guidance_scale=3.5,
                   true_cf_scale=1.0, num_inference_steps=20, seed=0,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate one image from ``prompt`` and rotate the comparison history.

    The models are created once at import time; this call only runs the
    pipeline, on a GPU requested through ``spaces.GPU(duration=75)``.

    Args:
        prompt: Text description. ``None`` or whitespace-only skips generation.
        height: Output height in pixels (cast to ``int``).
        width: Output width in pixels (cast to ``int``).
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        true_cf_scale: Forwarded to the pipeline as ``true_cfg_scale``.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        seed: Seed for the ``torch.Generator`` (cast to ``int``).
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress bars in the UI.

    Returns:
        ``(previous image, new image, previous prompt text, seed)`` for the
        four UI outputs. When the prompt is blank or the pipeline raises, the
        new image is ``None`` and the history is left unchanged.
    """
    global previous_image, previous_prompt
    no_history = "No previous generation"
    if prompt is None or not prompt.strip():
        return previous_image, None, previous_prompt or no_history, seed
    try:
        generator = torch.Generator(device=device).manual_seed(int(seed))
        # UI sliders may deliver floats; the pipeline needs integer sizes/steps.
        new_image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            true_cfg_scale=true_cf_scale,
            max_sequence_length=512,
            num_inference_steps=int(num_inference_steps),
            height=int(height),
            width=int(width),
            generator=generator,
        ).images[0]
    except Exception as e:
        # UI boundary: log the full traceback but keep the app responsive,
        # returning the unchanged history instead of failing the event.
        import traceback
        print(f"❌ Error: {e}")
        print(traceback.format_exc())
        return previous_image, None, previous_prompt or no_history, seed

    # Shift the old result into the "Previous" slot and remember the new one.
    # NOTE(review): under ZeroGPU this function may execute in a separate
    # worker process; confirm these global updates persist between calls.
    shown_image = previous_image
    shown_prompt = previous_prompt or no_history
    previous_image = new_image
    previous_prompt = prompt
    return shown_image, new_image, shown_prompt, seed
def reset_history():
    """Forget the stored generation and blank both image panels.

    Returns:
        A 3-tuple for (previous image, current image, previous-prompt text):
        two ``None`` images and the "No previous generation" placeholder.
    """
    global previous_image, previous_prompt
    previous_image, previous_prompt = None, None
    placeholder = "No previous generation"
    return (None, None, placeholder)
# Create Gradio interface
# Custom CSS for the element with elem_id="col-container": centre it
# horizontally and cap its width at 1400px.
css = """
#col-container {
margin: 0 auto;
max-width: 1400px;
}
"""
# Layout: a prompt/settings column on the left and a side-by-side
# "Previous" vs "Latest" comparison on the right.
with gr.Blocks(css=css, title="ConceptAligner") as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# 🎨 ConceptAligner Image Generator
Create stunning AI-generated images from text descriptions.
""")
        with gr.Row():
            # Left: prompt, action buttons and generation settings.
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    lines=8,
                    placeholder="Describe your image in detail...",
                )
                with gr.Row():
                    generate_btn = gr.Button("✨ Generate", variant="primary", scale=3)
                    reset_btn = gr.Button("🔄 Clear History", variant="secondary", scale=1)
                with gr.Accordion("⚙️ Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        step=1,
                        value=0,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=3.5,
                        info="Higher = follows prompt more closely (3-4 recommended)",
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of Steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=20,
                        info="More steps = higher quality but slower",
                    )
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=1024,
                            step=64,
                            value=512,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=1024,
                            step=64,
                            value=512,
                        )
                    # Hidden: keeps true CFG at 1.0 while still feeding the handler.
                    true_cfg_scale = gr.Slider(
                        label="True CFG Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=1.0,
                        visible=False,
                    )
            # Right: previous vs. latest result.
            with gr.Column(scale=2):
                gr.Markdown("### 📊 Your Generations")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("**Previous**")
                        prev_image = gr.Image(label="Previous", show_label=False, type="pil", height=450)
                        prev_prompt_display = gr.Textbox(
                            label="Previous Prompt",
                            lines=3,
                            interactive=False,
                            show_label=False,
                        )
                    with gr.Column():
                        gr.Markdown("**Latest**")
                        current_image = gr.Image(label="Current", show_label=False, type="pil", height=450)
        gr.Markdown("### 📝 Try This Example")
        gr.Examples(
            examples=EXAMPLE_PROMPTS,
            inputs=prompt_input,
            outputs=[prev_image, current_image, prev_prompt_display, seed],
            fn=generate_image,
            cache_examples=False,
        )
    # Event handlers: the button and Enter in the prompt box both generate.
    gr.on(
        triggers=[generate_btn.click, prompt_input.submit],
        fn=generate_image,
        inputs=[prompt_input, height, width, guidance_scale, true_cfg_scale, num_inference_steps, seed],
        outputs=[prev_image, current_image, prev_prompt_display, seed],
    )
    reset_btn.click(
        fn=reset_history,
        outputs=[prev_image, current_image, prev_prompt_display],
    )
if __name__ == "__main__":
    # Serve the demo with Gradio's default launch settings.
    demo.launch()