| """ | |
| ConceptAligner Hugging Face Demo | |
| Downloads weights from model repo at startup | |
| """ | |
import os

import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig

from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
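# NOTE: `aligner`, `text_encoder`, and `pipeline` are local modules of this
# Space (aligner.py, text_encoder.py, pipeline.py), not pip packages; they
# must sit next to app.py.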

# Configuration
MODEL_REPO = "Shaoan/ConceptAligner-Weights"
CHECKPOINT_DIR = "./checkpoint"

EXAMPLE_PROMPTS = [
    [
        "In the image, a single white duck walks proudly across a cobblestone "
        "street. It wears a red ribbon around its neck, and the morning sun "
        "glints off puddles from a recent rain. In the background, a few people "
        "watch and smile, giving the scene a playful charm. The duck's confident "
        "stride and upright posture make it appear oddly dignified."
    ]
]
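# gr.Examples expects one row (inner list) per example, with one entry per
# input component, hence the nested list even for a single prompt box.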


def download_checkpoint():
    """Download checkpoint files from the HF model repo."""
    print("Downloading checkpoint files...")
    files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"]
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for filename in files:
        local_path = os.path.join(CHECKPOINT_DIR, filename)
        if not os.path.exists(local_path):
            print(f"  Downloading {filename}...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
                # local_dir_use_symlinks is deprecated and ignored by recent
                # huggingface_hub versions; files are written into local_dir.
            )
    print("✓ All files ready!")


class ConceptAlignerModel:
    def __init__(self):
        download_checkpoint()
        self.checkpoint_path = CHECKPOINT_DIR
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.previous_image = None
        self.previous_prompt = None
        self.setup_models()

    def setup_models(self):
        """Load all models."""
        print(f"Loading models on {self.device}...")

        # Load ConceptAligner
        self.model = ConceptAligner().to(self.device).to(self.dtype)
        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
        self.model.load_state_dict(adapter_state, strict=True)

        # Load T5 encoder
        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
        # T5 ties its input embedding to `shared`; safetensors stores shared
        # tensors only once, so restore the alias before the strict load.
        if "t5_encoder.shared.weight" in adapter_state:
            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
        self.text_encoder.load_state_dict(adapter_state, strict=True)
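        # NOTE: black-forest-labs/FLUX.1-dev is a gated model; from_pretrained
        # below needs an authenticated token with access (on Spaces, usually
        # an HF_TOKEN secret).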
        # Load VAE
        vae = AutoencoderKL.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="vae", torch_dtype=self.dtype
        ).to(self.device)

        # Load transformer and attach the LoRA adapter before loading weights,
        # so the checkpoint's adapter keys line up with the module names
        transformer = FluxTransformer2DModel.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="transformer", torch_dtype=self.dtype
        )
        transformer_lora_config = LoraConfig(
            r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
            target_modules=[
                "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
                "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
                "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
                "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
            ],
        )
        transformer.add_adapter(transformer_lora_config)
        # Carried over from training; requires_grad has no effect on loading
        transformer.context_embedder.requires_grad_(True)
        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
        transformer.load_state_dict(transformer_state, strict=True)
        transformer = transformer.to(self.device)

        # Load the precomputed pooled CLIP embedding for the empty prompt
        self.empty_pooled_clip = torch.load(
            os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
            map_location=self.device
        ).to(self.dtype)

        # Create pipeline
        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
        )
        self.pipe = CustomFluxKontextPipeline(
            scheduler=noise_scheduler,
            aligner=self.model,
            transformer=transformer,
            vae=vae,
            text_embedder=self.text_encoder,
        ).to(self.device)
        print("✓ Model loaded!")

    def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
                       guidance_scale=3.5, true_cfg_scale=1.0, num_inference_steps=20, seed=1995):
        # `threshold` and `topk` are wired to the UI but are not currently
        # forwarded to the pipeline call below.
        if not prompt.strip():
            return self.previous_image, None, self.previous_prompt or ""
        try:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            current_image = self.pipe(
                prompt=prompt, guidance_scale=guidance_scale, true_cfg_scale=true_cfg_scale,
                max_sequence_length=512, num_inference_steps=num_inference_steps,
                height=height, width=width, generator=generator,
            ).images[0]
            # Shift the comparison window: the last result becomes "previous"
            prev_image = self.previous_image
            prev_prompt = self.previous_prompt or "No previous generation"
            self.previous_image = current_image
            self.previous_prompt = prompt
            return prev_image, current_image, prev_prompt
        except Exception as e:
            print(f"Error: {e}")
            return self.previous_image, None, self.previous_prompt or ""

    def reset_history(self):
        self.previous_image = None
        self.previous_prompt = None
        return None, None, "No previous generation"
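

# A minimal headless sketch (no Gradio), assuming the same checkpoint layout:
#   m = ConceptAlignerModel()
#   _, image, _ = m.generate_image("a white duck on a cobblestone street", seed=42)
#   image.save("duck.png")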

# Initialize model
print("Initializing ConceptAligner...")
model = ConceptAlignerModel()

# Create Gradio interface
with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...")
            with gr.Row():
                generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3)
                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1)
            with gr.Accordion("⚙️ Settings", open=True):
                guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
                num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps")
                seed = gr.Number(value=0, label="Seed", precision=0)
            with gr.Accordion("🔬 Advanced", open=False):
                true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG")
                threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold")
                topk = gr.Slider(0, 300, value=0, step=1, label="Top-K")
                with gr.Row():
                    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
                    width = gr.Slider(256, 1024, value=512, step=64, label="Width")
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Comparison View")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Previous**")
                    prev_image = gr.Image(label="Previous", type="pil", height=450)
                    prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False)
                with gr.Column():
                    gr.Markdown("**Current**")
                    current_image = gr.Image(label="Current", type="pil", height=450)

    gr.Markdown("### 📝 Example")
    gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)

    generate_btn.click(
        fn=model.generate_image,
        inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
        outputs=[prev_image, current_image, prev_prompt_display]
    )
    reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])

if __name__ == "__main__":
    demo.launch()
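
# On Spaces, launch() picks up host/port from the environment. For a shareable
# link when running locally, or to serialize long GPU jobs, something like
# `demo.queue(max_size=20).launch(share=True)` also works.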