| """ | |
| ConceptAligner - Same GPU behavior as FLUX demo | |
| Models loaded at startup, GPU allocated only for inference | |
| """ | |
| # CRITICAL: Import spaces FIRST | |
| import spaces | |
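# (ZeroGPU note: `spaces` has to be imported before torch, or anything else
#  that initializes CUDA, so its GPU-allocation hooks are installed first.)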
import os

import torch
import gradio as gr
from huggingface_hub import hf_hub_download, login
from safetensors.torch import load_file
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig

from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
# Login
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✓ Logged in to Hugging Face")

# Configuration
MODEL_REPO = "Shaoan/ConceptAligner-Weights"
CHECKPOINT_DIR = "./checkpoint"
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

EXAMPLE_PROMPTS = [
    ["""In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""],
]
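# Each row above is a one-element list because gr.Examples further down feeds
# a single input component (the prompt textbox).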
def download_checkpoint():
    """Download checkpoint files"""
    print("Downloading checkpoint files...")
    files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"]
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for filename in files:
        local_path = os.path.join(CHECKPOINT_DIR, filename)
        if not os.path.exists(local_path):
            print(f" Downloading {filename}...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
                token=HF_TOKEN,
            )
    print("✓ Checkpoint files ready!")
# Download at startup
download_checkpoint()

# Load models at startup (like FLUX does)
print("Loading models...")

# Load ConceptAligner
aligner_model = ConceptAligner().to(device).to(dtype)
adapter_state = load_file(os.path.join(CHECKPOINT_DIR, "model_1.safetensors"))
aligner_model.load_state_dict(adapter_state, strict=True)
print(" ✓ ConceptAligner")

# Load T5 encoder
text_encoder = LoraT5Embedder(device=device).to(dtype)
adapter_state = load_file(os.path.join(CHECKPOINT_DIR, "model_2.safetensors"))
# T5 ties the encoder's input-embedding table to the shared embedding, so a
# checkpoint that stores only `shared.weight` is mirrored onto the
# `embed_tokens` key before the strict load.
if "t5_encoder.shared.weight" in adapter_state:
    adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
text_encoder.load_state_dict(adapter_state, strict=True)
print(" ✓ T5 Encoder")
# Load VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=dtype,
    token=HF_TOKEN,
).to(device)
print(" ✓ VAE")
# Load transformer: build from config (random init), attach the LoRA adapter,
# then load the trained state dict from the checkpoint.
config = FluxTransformer2DModel.load_config(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    token=HF_TOKEN,
)
transformer = FluxTransformer2DModel.from_config(config, torch_dtype=dtype)

transformer_lora_config = LoraConfig(
    r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
    target_modules=[
        "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
        "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
        "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
        "proj_mlp", "proj_out", "norm.linear", "norm1.linear",
    ],
)
transformer.add_adapter(transformer_lora_config)
# context_embedder was trained alongside the LoRA; the flag has no effect at
# inference but mirrors the training setup.
transformer.context_embedder.requires_grad_(True)

transformer_state = load_file(os.path.join(CHECKPOINT_DIR, "model.safetensors"))
# strict=False tolerates key-name differences introduced by the PEFT wrapping.
transformer.load_state_dict(transformer_state, strict=False)
transformer = transformer.to(device).to(dtype)
print(" ✓ Transformer")
# Load scheduler
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="scheduler",
    token=HF_TOKEN,
)

# Create pipeline
pipe = CustomFluxKontextPipeline(
    scheduler=noise_scheduler,
    aligner=aligner_model,
    transformer=transformer,
    vae=vae,
    text_embedder=text_encoder,
).to(device)
print("✓ Models loaded and ready!")
torch.cuda.empty_cache()
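# (ZeroGPU note: moving modules to "cuda" at import time works because the
#  `spaces` package intercepts torch's CUDA calls during startup; weights are
#  materialized on a real GPU only while a @spaces.GPU function holds one.)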
# History tracking
previous_image = None
previous_prompt = None
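# Note: module-level globals are shared by every visitor of the Space, so the
# previous/current comparison can mix sessions under concurrent use;
# per-session history would need gr.State instead.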
# @spaces.GPU attaches a ZeroGPU device for the duration of this call only,
# matching the "GPU allocated only for inference" behavior described above.
@spaces.GPU
def generate_image(prompt, height=512, width=512, guidance_scale=3.5,
                   true_cfg_scale=1.0, num_inference_steps=20, seed=0,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate an image - models are already loaded"""
    global previous_image, previous_prompt

    if not prompt.strip():
        return previous_image, None, previous_prompt or "No previous generation", seed

    try:
        generator = torch.Generator(device=device).manual_seed(int(seed))
        current_image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            true_cfg_scale=true_cfg_scale,
            max_sequence_length=512,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            generator=generator,
        ).images[0]

        # Store for comparison
        prev_image = previous_image
        prev_prompt = previous_prompt or "No previous generation"
        previous_image = current_image
        previous_prompt = prompt
        return prev_image, current_image, prev_prompt, seed
    except Exception as e:
        import traceback
        print(f"✗ Error: {e}")
        print(traceback.format_exc())
        return previous_image, None, previous_prompt or "", seed
def reset_history():
    """Clear generation history"""
    global previous_image, previous_prompt
    previous_image = None
    previous_prompt = None
    return None, None, "No previous generation"

# Create Gradio interface
css = """
#col-container {
    margin: 0 auto;
    max-width: 1400px;
}
"""
with gr.Blocks(css=css, title="ConceptAligner") as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # 🎨 ConceptAligner Image Generator
        Create stunning AI-generated images from text descriptions.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    lines=8,
                    placeholder="Describe your image in detail...",
                )
                with gr.Row():
                    generate_btn = gr.Button("✨ Generate", variant="primary", scale=3)
                    reset_btn = gr.Button("🔄 Clear History", variant="secondary", scale=1)

                with gr.Accordion("⚙️ Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        step=1,
                        value=0,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=3.5,
                        info="Higher = follows prompt more closely (3-4 recommended)",
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of Steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=20,
                        info="More steps = higher quality but slower",
                    )
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=1024,
                            step=64,
                            value=512,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=1024,
                            step=64,
                            value=512,
                        )
                    true_cfg_scale = gr.Slider(
                        label="True CFG Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=1.0,
                        visible=False,
                    )
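                    # The hidden True CFG slider stays in the layout so the
                    # gr.on() inputs list below matches generate_image's signature.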
            with gr.Column(scale=2):
                gr.Markdown("### 📸 Your Generations")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("**Previous**")
                        prev_image = gr.Image(label="Previous", show_label=False, type="pil", height=450)
                        prev_prompt_display = gr.Textbox(
                            label="Previous Prompt",
                            lines=3,
                            interactive=False,
                            show_label=False,
                        )
                    with gr.Column():
                        gr.Markdown("**Latest**")
                        current_image = gr.Image(label="Current", show_label=False, type="pil", height=450)

        gr.Markdown("### 💡 Try This Example")
        gr.Examples(
            examples=EXAMPLE_PROMPTS,
            inputs=prompt_input,
            outputs=[prev_image, current_image, prev_prompt_display, seed],
            fn=generate_image,
            cache_examples=False,
        )
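        # Note: with cache_examples=False, selecting the example only fills the
        # prompt box; press Generate to run it.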
    # Event handlers
    gr.on(
        triggers=[generate_btn.click, prompt_input.submit],
        fn=generate_image,
        inputs=[prompt_input, height, width, guidance_scale, true_cfg_scale, num_inference_steps, seed],
        outputs=[prev_image, current_image, prev_prompt_display, seed],
    )
    reset_btn.click(
        fn=reset_history,
        outputs=[prev_image, current_image, prev_prompt_display],
    )

if __name__ == "__main__":
    demo.launch()