""" ConceptAligner - Same GPU behavior as FLUX demo Models loaded at startup, GPU allocated only for inference """ # CRITICAL: Import spaces FIRST import spaces import torch import gradio as gr import os from huggingface_hub import hf_hub_download, login from safetensors.torch import load_file from aligner import ConceptAligner from text_encoder import LoraT5Embedder from pipeline import CustomFluxKontextPipeline from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL from peft import LoraConfig # Login HF_TOKEN = os.environ.get("HF_TOKEN") if HF_TOKEN: login(token=HF_TOKEN) print("✓ Logged in to Hugging Face") # Configuration MODEL_REPO = "Shaoan/ConceptAligner-Weights" CHECKPOINT_DIR = "./checkpoint" dtype = torch.bfloat16 device = "cuda" if torch.cuda.is_available() else "cpu" EXAMPLE_PROMPTS = [ ["""In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""] ] def download_checkpoint(): """Download checkpoint files""" print("Downloading checkpoint files...") files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"] os.makedirs(CHECKPOINT_DIR, exist_ok=True) for filename in files: local_path = os.path.join(CHECKPOINT_DIR, filename) if not os.path.exists(local_path): print(f" Downloading {filename}...") hf_hub_download( repo_id=MODEL_REPO, filename=filename, local_dir=CHECKPOINT_DIR, token=HF_TOKEN ) print("✓ Checkpoint files ready!") # Download at startup download_checkpoint() # Load models at startup (like FLUX does) print("Loading models...") # Load ConceptAligner aligner_model = ConceptAligner().to(device).to(dtype) adapter_state = load_file(os.path.join(CHECKPOINT_DIR, "model_1.safetensors")) aligner_model.load_state_dict(adapter_state, strict=True) print(" ✓ ConceptAligner") # Load T5 encoder text_encoder = LoraT5Embedder(device=device).to(dtype) adapter_state = load_file(os.path.join(CHECKPOINT_DIR, "model_2.safetensors")) if "t5_encoder.shared.weight" in adapter_state: adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"] text_encoder.load_state_dict(adapter_state, strict=True) print(" ✓ T5 Encoder") # Load VAE vae = AutoencoderKL.from_pretrained( 'black-forest-labs/FLUX.1-dev', subfolder="vae", torch_dtype=dtype, token=HF_TOKEN ).to(device) print(" ✓ VAE") # Load transformer config = FluxTransformer2DModel.load_config( 'black-forest-labs/FLUX.1-dev', subfolder="transformer", token=HF_TOKEN ) transformer = FluxTransformer2DModel.from_config(config, torch_dtype=dtype) transformer_lora_config = LoraConfig( r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian", target_modules=[ "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0", "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out", "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2", "proj_mlp", "proj_out", "norm.linear", "norm1.linear" ], ) transformer.add_adapter(transformer_lora_config) transformer.context_embedder.requires_grad_(True) transformer_state = load_file(os.path.join(CHECKPOINT_DIR, "model.safetensors")) transformer.load_state_dict(transformer_state, strict=False) transformer = transformer.to(device).to(dtype) print(" ✓ Transformer") # Load scheduler 
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    subfolder="scheduler",
    token=HF_TOKEN,
)

# Create pipeline
pipe = CustomFluxKontextPipeline(
    scheduler=noise_scheduler,
    aligner=aligner_model,
    transformer=transformer,
    vae=vae,
    text_embedder=text_encoder,
).to(device)
print("✅ Models loaded and ready!")
torch.cuda.empty_cache()

# History tracking
previous_image = None
previous_prompt = None


@spaces.GPU(duration=75)
@torch.no_grad()
def generate_image(prompt, height=512, width=512, guidance_scale=3.5,
                   true_cfg_scale=1.0, num_inference_steps=20, seed=0,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate an image; the models are already loaded at startup."""
    global previous_image, previous_prompt

    if not prompt.strip():
        return previous_image, None, previous_prompt or "No previous generation", seed

    try:
        generator = torch.Generator(device=device).manual_seed(int(seed))
        current_image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            true_cfg_scale=true_cfg_scale,
            max_sequence_length=512,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            generator=generator,
        ).images[0]

        # Store for comparison: the latest result moves to the "previous" slot
        prev_image = previous_image
        prev_prompt = previous_prompt or "No previous generation"
        previous_image = current_image
        previous_prompt = prompt

        return prev_image, current_image, prev_prompt, seed
    except Exception as e:
        import traceback
        print(f"❌ Error: {e}")
        print(traceback.format_exc())
        return previous_image, None, previous_prompt or "", seed


def reset_history():
    """Clear the generation history."""
    global previous_image, previous_prompt
    previous_image = None
    previous_prompt = None
    return None, None, "No previous generation"


# Create Gradio interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1400px;
}
"""

with gr.Blocks(css=css, title="ConceptAligner") as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # 🎨 ConceptAligner Image Generator
        Create stunning AI-generated images from text descriptions.
""") with gr.Row(): with gr.Column(scale=1): prompt_input = gr.Textbox( label="Prompt", lines=8, placeholder="Describe your image in detail...", ) with gr.Row(): generate_btn = gr.Button("✨ Generate", variant="primary", scale=3) reset_btn = gr.Button("🔄 Clear History", variant="secondary", scale=1) with gr.Accordion("⚙️ Settings", open=False): seed = gr.Slider( label="Seed", minimum=0, maximum=2147483647, step=1, value=0, ) guidance_scale = gr.Slider( label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.5, value=3.5, info="Higher = follows prompt more closely (3-4 recommended)" ) num_inference_steps = gr.Slider( label="Number of Steps", minimum=10, maximum=50, step=1, value=20, info="More steps = higher quality but slower" ) with gr.Row(): width = gr.Slider( label="Width", minimum=256, maximum=1024, step=64, value=512, ) height = gr.Slider( label="Height", minimum=256, maximum=1024, step=64, value=512, ) true_cfg_scale = gr.Slider( label="True CFG Scale", minimum=1.0, maximum=10.0, step=0.5, value=1.0, visible=False ) with gr.Column(scale=2): gr.Markdown("### 📊 Your Generations") with gr.Row(): with gr.Column(): gr.Markdown("**Previous**") prev_image = gr.Image(label="Previous", show_label=False, type="pil", height=450) prev_prompt_display = gr.Textbox( label="Previous Prompt", lines=3, interactive=False, show_label=False ) with gr.Column(): gr.Markdown("**Latest**") current_image = gr.Image(label="Current", show_label=False, type="pil", height=450) gr.Markdown("### 📝 Try This Example") gr.Examples( examples=EXAMPLE_PROMPTS, inputs=prompt_input, outputs=[prev_image, current_image, prev_prompt_display, seed], fn=generate_image, cache_examples=False ) # Event handlers gr.on( triggers=[generate_btn.click, prompt_input.submit], fn=generate_image, inputs=[prompt_input, height, width, guidance_scale, true_cfg_scale, num_inference_steps, seed], outputs=[prev_image, current_image, prev_prompt_display, seed] ) reset_btn.click( fn=reset_history, outputs=[prev_image, current_image, prev_prompt_display] ) if __name__ == "__main__": demo.launch()