""" ConceptAligner Hugging Face Demo Downloads weights from model repo at startup """ import torch import gradio as gr import os from huggingface_hub import hf_hub_download from safetensors.torch import load_file from aligner import ConceptAligner from text_encoder import LoraT5Embedder from pipeline import CustomFluxKontextPipeline from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL from peft import LoraConfig # Configuration MODEL_REPO = "Shaoan/ConceptAligner-Weights" CHECKPOINT_DIR = "./checkpoint" EXAMPLE_PROMPTS = [ [ """In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""] ] def download_checkpoint(): """Download checkpoint files from HF model repo""" print("Downloading checkpoint files...") files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"] os.makedirs(CHECKPOINT_DIR, exist_ok=True) for filename in files: local_path = os.path.join(CHECKPOINT_DIR, filename) if not os.path.exists(local_path): print(f" Downloading {filename}...") hf_hub_download( repo_id=MODEL_REPO, filename=filename, local_dir=CHECKPOINT_DIR, local_dir_use_symlinks=False ) print("✓ All files ready!") class ConceptAlignerModel: def __init__(self): download_checkpoint() self.checkpoint_path = CHECKPOINT_DIR self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 self.previous_image = None self.previous_prompt = None self.setup_models() def setup_models(self): """Load all models""" print(f"Loading models on {self.device}...") # Load ConceptAligner self.model = ConceptAligner().to(self.device).to(self.dtype) adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors")) self.model.load_state_dict(adapter_state, strict=True) # Load T5 encoder self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype) adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors")) if "t5_encoder.shared.weight" in adapter_state: adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"] self.text_encoder.load_state_dict(adapter_state, strict=True) # Load VAE vae = AutoencoderKL.from_pretrained( 'black-forest-labs/FLUX.1-dev', subfolder="vae", torch_dtype=self.dtype ).to(self.device) # Load transformer transformer = FluxTransformer2DModel.from_pretrained( 'black-forest-labs/FLUX.1-dev', subfolder="transformer", torch_dtype=self.dtype ) transformer_lora_config = LoraConfig( r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian", target_modules=[ "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0", "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out", "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2", "proj_mlp", "proj_out", "norm.linear", "norm1.linear" ], ) transformer.add_adapter(transformer_lora_config) transformer.context_embedder.requires_grad_(True) transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors")) transformer.load_state_dict(transformer_state, strict=True) transformer = transformer.to(self.device) # Load empty pooled clip self.empty_pooled_clip = torch.load( os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"), map_location=self.device ).to(self.dtype) # Create pipeline noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 'black-forest-labs/FLUX.1-dev', subfolder="scheduler" ) self.pipe = CustomFluxKontextPipeline( scheduler=noise_scheduler, aligner=self.model, transformer=transformer, vae=vae, text_embedder=self.text_encoder, ).to(self.device) print("✓ Model loaded!") @torch.no_grad() def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512, guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995): if not prompt.strip(): return self.previous_image, None, self.previous_prompt or "" try: generator = torch.Generator(device=self.device).manual_seed(int(seed)) current_image = self.pipe( prompt=prompt, guidance_scale=guidance_scale, true_cfg_scale=true_cf_scale, max_sequence_length=512, num_inference_steps=num_inference_steps, height=height, width=width, generator=generator, ).images[0] prev_image = self.previous_image prev_prompt = self.previous_prompt or "No previous generation" self.previous_image = current_image self.previous_prompt = prompt return prev_image, current_image, prev_prompt except Exception as e: print(f"Error: {e}") return self.previous_image, None, self.previous_prompt or "" def reset_history(self): self.previous_image = None self.previous_prompt = None return None, None, "No previous generation" # Initialize model print("Initializing ConceptAligner...") model = ConceptAlignerModel() # Create Gradio interface with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!") with gr.Row(): with gr.Column(scale=1): prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...") with gr.Row(): generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3) reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1) with gr.Accordion("⚙️ Settings", open=True): guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale") num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps") seed = gr.Number(value=0, label="Seed", precision=0) with gr.Accordion("🔬 Advanced", open=False): true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG") threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold") topk = gr.Slider(0, 300, value=0, step=1, label="Top-K") with gr.Row(): height = gr.Slider(256, 1024, value=512, step=64, label="Height") width = gr.Slider(256, 1024, value=512, step=64, label="Width") with gr.Column(scale=2): gr.Markdown("### 📊 Comparison View") with gr.Row(): with gr.Column(): gr.Markdown("**Previous**") prev_image = gr.Image(label="Previous", type="pil", height=450) prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False) with gr.Column(): gr.Markdown("**Current**") current_image = gr.Image(label="Current", type="pil", height=450) gr.Markdown("### 📝 Example") gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input) generate_btn.click( fn=model.generate_image, inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed], outputs=[prev_image, current_image, prev_prompt_display] ) reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display]) if __name__ == "__main__": demo.launch()