# ConceptAligner / app.py
# Hugging Face Hub page metadata captured with this file (not part of the program):
#   uploader: Shaoan
#   commit:   6af382a (verified) - "Upload folder using huggingface_hub"
#   size:     10.1 kB
"""
ConceptAligner Hugging Face Demo
"""
import torch
import gradio as gr
import os
from huggingface_hub import hf_hub_download, login
from safetensors.torch import load_file
from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig
# For HF Spaces GPU support: `spaces` is an optional dependency. When it can
# be imported, GPU_AVAILABLE is True and the generation function is later
# wrapped with spaces.GPU; otherwise the app runs without that wrapper.
try:
    import spaces
except ImportError:
    GPU_AVAILABLE = False
    print("⚠️ spaces package not available, running without @spaces.GPU decorator")
else:
    GPU_AVAILABLE = True
# Login with token from environment. HF_TOKEN stays None when the variable is
# unset; an unset or empty token skips the login call entirely.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✓ Logged in to Hugging Face")
# Configuration
# Hub repository the checkpoint files are fetched from.
MODEL_REPO = "Shaoan/ConceptAligner-Weights"
# Local directory the checkpoint files are stored in.
CHECKPOINT_DIR = "./checkpoint"
# Rows for the examples widget: one single-column row per example prompt.
EXAMPLE_PROMPTS = [
    [
        "In the image, a single white duck walks proudly across a cobblestone street. "
        "It wears a red ribbon around its neck, and the morning sun glints off puddles "
        "from a recent rain. "
        "In the background, a few people watch and smile, giving the scene a playful charm. "
        "The duck's confident stride and upright posture make it appear oddly dignified."
    ],
]
def download_checkpoint():
    """Download checkpoint files from HF model repo.

    Creates CHECKPOINT_DIR if needed, then fetches each required file from
    MODEL_REPO into it. A file that already exists locally is skipped, so
    repeated start-ups do not re-download anything. HF_TOKEN is passed
    through to the Hub call.
    """
    print("Downloading checkpoint files...")
    files = (
        "model.safetensors",
        "model_1.safetensors",
        "model_2.safetensors",
        "empty_pooled_clip.pt",
    )
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for filename in files:
        local_path = os.path.join(CHECKPOINT_DIR, filename)
        if os.path.exists(local_path):
            # Already present from an earlier run; nothing to fetch.
            continue
        print(f" Downloading {filename}...")
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=filename,
            local_dir=CHECKPOINT_DIR,
            token=HF_TOKEN,
        )
        print(f" ✓ {filename} downloaded")
    print("✓ All checkpoint files ready!")
class ConceptAlignerModel:
    """Owns the loaded generation pipeline and a one-step generation history.

    Construction downloads the checkpoint files, selects device and dtype
    from CUDA availability, and loads every model into ``self.pipe``.
    ``previous_image`` / ``previous_prompt`` remember the last successful
    generation so the UI can show it next to the new one.
    """

    def __init__(self):
        """Fetch checkpoints, choose device/dtype, and load all models."""
        download_checkpoint()
        self.checkpoint_path = CHECKPOINT_DIR
        cuda_ok = torch.cuda.is_available()
        self.device = 'cuda' if cuda_ok else 'cpu'
        # bfloat16 on GPU, float32 on CPU.
        self.dtype = torch.bfloat16 if cuda_ok else torch.float32
        self.previous_image = None
        self.previous_prompt = None
        self.setup_models()

    def setup_models(self):
        """Load all models and assemble ``self.pipe``.

        Loads, in order: the ConceptAligner (``model_1.safetensors``), the
        LoRA T5 text encoder (``model_2.safetensors``), the FLUX.1-dev VAE,
        a FLUX transformer built from config with LoRA adapters attached
        (``model.safetensors``), the empty pooled CLIP tensor, and the
        scheduler. Sets ``self.model``, ``self.text_encoder``,
        ``self.empty_pooled_clip`` and ``self.pipe``.
        """
        print(f"Loading models on {self.device}...")

        # Load ConceptAligner
        print(" Loading ConceptAligner...")
        self.model = ConceptAligner().to(self.device).to(self.dtype)
        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
        self.model.load_state_dict(adapter_state, strict=True)
        print(" ✓ ConceptAligner loaded")

        # Load T5 encoder
        print(" Loading fine-tuned T5 encoder...")
        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
        if "t5_encoder.shared.weight" in adapter_state:
            # The checkpoint stores the embedding under the "shared" key only;
            # duplicate it under the encoder key so the strict load succeeds.
            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
        self.text_encoder.load_state_dict(adapter_state, strict=True)
        print(" ✓ T5 encoder loaded")

        # Download VAE
        print(" Loading VAE from FLUX.1-dev...")
        vae = AutoencoderKL.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="vae",
            torch_dtype=self.dtype,
            token=HF_TOKEN,
        ).to(self.device)
        print(" ✓ VAE loaded")

        # Create transformer from config
        print(" Downloading transformer config...")
        config = FluxTransformer2DModel.load_config(
            'black-forest-labs/FLUX.1-dev',
            subfolder="transformer",
            token=HF_TOKEN,
        )
        print(" Initializing transformer...")
        transformer = FluxTransformer2DModel.from_config(config, torch_dtype=self.dtype)

        print(" Adding LoRA adapters...")
        transformer_lora_config = LoraConfig(
            r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
            target_modules=[
                "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
                "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
                "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
                "proj_mlp", "proj_out", "norm.linear", "norm1.linear",
            ],
        )
        transformer.add_adapter(transformer_lora_config)
        transformer.context_embedder.requires_grad_(True)

        print(" Loading fine-tuned transformer weights...")
        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
        # strict=False tolerates key mismatches; report them instead of
        # ignoring them silently so a wrong checkpoint is visible in the logs.
        load_result = transformer.load_state_dict(transformer_state, strict=False)
        missing = list(getattr(load_result, "missing_keys", ()))
        unexpected = list(getattr(load_result, "unexpected_keys", ()))
        if missing or unexpected:
            print(
                f" ⚠️ Transformer checkpoint key mismatch: "
                f"{len(missing)} missing, {len(unexpected)} unexpected"
            )
        transformer = transformer.to(self.device).to(self.dtype)
        print(" ✓ Transformer loaded")

        # Load empty pooled clip
        self.empty_pooled_clip = torch.load(
            os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
            map_location=self.device,
            weights_only=True,
        ).to(self.dtype)

        # Create scheduler
        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="scheduler",
            token=HF_TOKEN,
        )

        # Create pipeline
        self.pipe = CustomFluxKontextPipeline(
            scheduler=noise_scheduler,
            aligner=self.model,
            transformer=transformer,
            vae=vae,
            text_embedder=self.text_encoder,
        ).to(self.device)

        print("✅ ALL MODELS LOADED!")
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            print(f"📊 GPU Memory: {allocated:.2f}GB allocated")

    @torch.no_grad()
    def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
                       guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
        """Generate an image for ``prompt`` and rotate the history.

        Args:
            prompt: Text prompt. ``None``, empty, or whitespace-only prompts
                skip generation.
            threshold: Accepted from the UI but not forwarded to the pipeline
                by this method.
            topk: Accepted from the UI but not forwarded to the pipeline by
                this method.
            height: Output height; converted to ``int`` before use.
            width: Output width; converted to ``int`` before use.
            guidance_scale: Passed to the pipeline as ``guidance_scale``.
            true_cf_scale: Passed to the pipeline as ``true_cfg_scale``.
            num_inference_steps: Step count; converted to ``int`` before use.
            seed: Seed for the ``torch.Generator``; converted to ``int``.

        Returns:
            ``(previous_image, current_image, previous_prompt)``. On a skipped
            or failed generation ``current_image`` is ``None`` and the stored
            history is left unchanged.
        """
        if not prompt or not prompt.strip():
            return self.previous_image, None, self.previous_prompt or ""

        try:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            # UI widgets may deliver whole numbers as floats; the pipeline
            # is given integers for sizes and step count.
            current_image = self.pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                true_cfg_scale=true_cf_scale,
                max_sequence_length=512,
                num_inference_steps=int(num_inference_steps),
                height=int(height),
                width=int(width),
                generator=generator,
            ).images[0]

            prev_image = self.previous_image
            prev_prompt = self.previous_prompt or "No previous generation"
            self.previous_image = current_image
            self.previous_prompt = prompt
            return prev_image, current_image, prev_prompt
        except Exception as e:
            # UI boundary: log the failure and keep showing the last good
            # result rather than propagating the exception.
            import traceback
            print(f"❌ Error: {e}")
            print(traceback.format_exc())
            return self.previous_image, None, self.previous_prompt or ""

    def reset_history(self):
        """Forget the stored image/prompt and return cleared UI values."""
        self.previous_image = None
        self.previous_prompt = None
        return None, None, "No previous generation"
# Initialize model (downloads checkpoints and loads all weights at import time)
print("🚀 Initializing ConceptAligner...")
model = ConceptAlignerModel()

# Wrap generation function with @spaces.GPU if available; otherwise call the
# bound method directly.
generate_fn = spaces.GPU(model.generate_image) if GPU_AVAILABLE else model.generate_image
# Create Gradio interface
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed from the order of the `with` statements — confirm the layout
# (in particular where the Height/Width row and the Example section belong).
with gr.Blocks(title="ConceptAligner") as demo:
    gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")

    with gr.Row():
        # Left column: prompt, action buttons, and generation settings.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...")
            with gr.Row():
                generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3)
                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1)
            with gr.Accordion("⚙️ Settings", open=True):
                guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
                num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps")
                seed = gr.Number(value=0, label="Seed", precision=0)
            with gr.Accordion("🔬 Advanced", open=False):
                true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG")
                threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold")
                topk = gr.Slider(0, 300, value=0, step=1, label="Top-K")
                with gr.Row():
                    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
                    width = gr.Slider(256, 1024, value=512, step=64, label="Width")

        # Right column: previous result next to the current one.
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Comparison View")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Previous**")
                    prev_image = gr.Image(label="Previous", type="pil", height=450)
                    prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False)
                with gr.Column():
                    gr.Markdown("**Current**")
                    current_image = gr.Image(label="Current", type="pil", height=450)

    gr.Markdown("### 📝 Example")
    gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)

    # Input order must match the positional parameters of generate_image:
    # prompt, threshold, topk, height, width, guidance_scale, true_cf_scale,
    # num_inference_steps, seed.
    generate_btn.click(
        fn=generate_fn,
        inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
        outputs=[prev_image, current_image, prev_prompt_display],
    )
    reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()