ConceptAligner / app.py
Shaoan's picture
Upload folder using huggingface_hub
36f6af4 verified
raw
history blame
7.43 kB
"""
ConceptAligner Hugging Face Demo
Downloads weights from model repo at startup
"""
import torch
import gradio as gr
import os
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig
# Demo configuration: Hub repository holding the weights and the local
# directory the checkpoint files are stored in.
MODEL_REPO = "Shaoan/ConceptAligner-Weights"
CHECKPOINT_DIR = "./checkpoint"

# Example rows offered in the UI; each row carries a single prompt string.
EXAMPLE_PROMPTS = [
    [
        (
            "In the image, a single white duck walks proudly across a cobblestone street. "
            "It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. "
            "In the background, a few people watch and smile, giving the scene a playful charm. "
            "The duck's confident stride and upright posture make it appear oddly dignified."
        ),
    ],
]
def download_checkpoint(repo_id=None, checkpoint_dir=None, files=None):
    """Download the checkpoint files from the Hugging Face model repo.

    Files that already exist in the target directory are not fetched again.

    Args:
        repo_id: Hub repository to download from. Defaults to ``MODEL_REPO``.
        checkpoint_dir: Local directory the files are stored in; it is
            created when missing. Defaults to ``CHECKPOINT_DIR``.
        files: Iterable of file names inside the repository. Defaults to the
            three safetensors files plus ``empty_pooled_clip.pt``.

    Returns:
        None. Progress is reported on stdout.
    """
    # Defaults are resolved at call time so the module-level constants stay
    # the single source of truth.
    if repo_id is None:
        repo_id = MODEL_REPO
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    if files is None:
        files = (
            "model.safetensors",
            "model_1.safetensors",
            "model_2.safetensors",
            "empty_pooled_clip.pt",
        )

    print("Downloading checkpoint files...")
    os.makedirs(checkpoint_dir, exist_ok=True)
    for filename in files:
        local_path = os.path.join(checkpoint_dir, filename)
        if os.path.exists(local_path):
            continue
        print(f" Downloading {filename}...")
        # `local_dir_use_symlinks` is intentionally not passed: it is
        # deprecated in recent huggingface_hub releases, which write real
        # files into `local_dir` on their own.
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=checkpoint_dir,
        )
    print("✓ All files ready!")
class ConceptAlignerModel:
    """Owns the loaded models and the generation history used by the demo.

    Construction downloads the checkpoint files, loads every model component
    and builds the generation pipeline. The instance also remembers the last
    successfully generated image and its prompt so the UI can show a
    previous/current comparison.

    NOTE(review): the history lives on the instance, so it is shared by every
    caller that uses the same instance.
    """

    def __init__(self):
        download_checkpoint()
        self.checkpoint_path = CHECKPOINT_DIR
        has_cuda = torch.cuda.is_available()
        self.device = 'cuda' if has_cuda else 'cpu'
        # bfloat16 on GPU, full precision on CPU.
        self.dtype = torch.bfloat16 if has_cuda else torch.float32
        # Last successful generation (None until the first success).
        self.previous_image = None
        self.previous_prompt = None
        self.setup_models()

    def setup_models(self):
        """Load all models and assemble ``self.pipe``.

        Sets ``self.model``, ``self.text_encoder``, ``self.empty_pooled_clip``
        and ``self.pipe``. Every ``load_state_dict`` call uses ``strict=True``,
        so a checkpoint whose keys do not match raises instead of loading
        partially.
        """
        print(f"Loading models on {self.device}...")

        # Load ConceptAligner
        self.model = ConceptAligner().to(self.device).to(self.dtype)
        aligner_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
        self.model.load_state_dict(aligner_state, strict=True)

        # Load T5 encoder
        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
        encoder_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
        # Reuse the shared embedding under the encoder's embed_tokens key so
        # the strict load below finds it (presumably the checkpoint stores the
        # tied weight only once -- TODO confirm).
        if "t5_encoder.shared.weight" in encoder_state:
            encoder_state["t5_encoder.encoder.embed_tokens.weight"] = encoder_state["t5_encoder.shared.weight"]
        self.text_encoder.load_state_dict(encoder_state, strict=True)

        # Load VAE
        vae = AutoencoderKL.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="vae", torch_dtype=self.dtype
        ).to(self.device)

        # Load transformer
        transformer = FluxTransformer2DModel.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="transformer", torch_dtype=self.dtype
        )
        transformer_lora_config = LoraConfig(
            r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
            target_modules=[
                "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
                "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
                "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
                "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
            ],
        )
        # The adapter is attached before the strict load below; presumably
        # model.safetensors contains the LoRA parameters as well -- TODO confirm.
        transformer.add_adapter(transformer_lora_config)
        # NOTE(review): kept from the original; generation runs under
        # no_grad, so confirm whether this is still needed.
        transformer.context_embedder.requires_grad_(True)
        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
        transformer.load_state_dict(transformer_state, strict=True)
        transformer = transformer.to(self.device)

        # Load empty pooled clip
        # NOTE(review): torch.load unpickles the downloaded file; if it only
        # holds a tensor, consider passing weights_only=True -- confirm first.
        self.empty_pooled_clip = torch.load(
            os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
            map_location=self.device
        ).to(self.dtype)

        # Create pipeline
        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
        )
        self.pipe = CustomFluxKontextPipeline(
            scheduler=noise_scheduler,
            aligner=self.model,
            transformer=transformer,
            vae=vae,
            text_embedder=self.text_encoder,
        ).to(self.device)
        print("✓ Model loaded!")

    @torch.no_grad()
    def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
                       guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
        """Generate an image for ``prompt`` and rotate the comparison history.

        Args:
            prompt: Text prompt. ``None``, empty or whitespace-only prompts
                skip generation.
            threshold: Accepted for the UI wiring; not forwarded to the
                pipeline by this method.
            topk: Accepted for the UI wiring; not forwarded to the pipeline
                by this method.
            height: Output height; coerced to ``int``.
            width: Output width; coerced to ``int``.
            guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
            true_cf_scale: Forwarded to the pipeline as ``true_cfg_scale``.
            num_inference_steps: Number of steps; coerced to ``int``.
            seed: Seed for the ``torch.Generator``; coerced to ``int``.

        Returns:
            A ``(previous_image, current_image, previous_prompt)`` tuple. On
            success the history is advanced to the new image and prompt. When
            the prompt is blank or generation fails, ``current_image`` is
            ``None`` and the history is left unchanged.
        """
        # A cleared textbox may deliver None rather than ""; treat both alike.
        if not prompt or not prompt.strip():
            return self.previous_image, None, self.previous_prompt or ""

        try:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            # UI sliders can hand over floats (e.g. 512.0); size and step
            # arguments are coerced the same way the seed already is.
            current_image = self.pipe(
                prompt=prompt, guidance_scale=guidance_scale, true_cfg_scale=true_cf_scale,
                max_sequence_length=512, num_inference_steps=int(num_inference_steps),
                height=int(height), width=int(width), generator=generator,
            ).images[0]
        except Exception as e:
            # Best-effort: keep the UI responsive, but leave a full traceback
            # in the logs so the failure can be diagnosed.
            import traceback

            print(f"Error: {e}")
            traceback.print_exc()
            return self.previous_image, None, self.previous_prompt or ""

        # Hand back the old entry for the "previous" pane, then remember the
        # new result for the next call.
        prev_image = self.previous_image
        prev_prompt = self.previous_prompt or "No previous generation"
        self.previous_image = current_image
        self.previous_prompt = prompt
        return prev_image, current_image, prev_prompt

    def reset_history(self):
        """Forget the stored generation and return cleared UI values.

        Returns:
            ``(None, None, "No previous generation")``.
        """
        self.previous_image = None
        self.previous_prompt = None
        return None, None, "No previous generation"
# Initialize model
# Built once at import time; the button callbacks below are bound to this
# single instance.
print("Initializing ConceptAligner...")
model = ConceptAlignerModel()

# Create Gradio interface
# NOTE(review): the nesting below is reconstructed from the order of the
# layout containers -- confirm against the deployed layout.
with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")

    with gr.Row():
        # Left column: prompt, action buttons and generation settings.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...")

            with gr.Row():
                generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3)
                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1)

            with gr.Accordion("⚙️ Settings", open=True):
                guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
                num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps")
                seed = gr.Number(value=0, label="Seed", precision=0)

            with gr.Accordion("🔬 Advanced", open=False):
                true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG")
                threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold")
                topk = gr.Slider(0, 300, value=0, step=1, label="Top-K")
                with gr.Row():
                    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
                    width = gr.Slider(256, 1024, value=512, step=64, label="Width")

        # Right column: previous vs. current result.
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Comparison View")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Previous**")
                    prev_image = gr.Image(label="Previous", type="pil", height=450)
                    prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False)
                with gr.Column():
                    gr.Markdown("**Current**")
                    current_image = gr.Image(label="Current", type="pil", height=450)

    gr.Markdown("### 📝 Example")
    gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)

    # The input order must match the positional parameters of
    # ConceptAlignerModel.generate_image.
    generate_btn.click(
        fn=model.generate_image,
        inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
        outputs=[prev_image, current_image, prev_prompt_display]
    )
    reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])

if __name__ == "__main__":
    demo.launch()