import torch
import os
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig
import gradio as gr

# Configuration
MODEL_REPO = "Shaoan/ConceptAligner-Weights"  # Your model repo
CHECKPOINT_DIR = "./checkpoint"


def download_checkpoint():
    """Download checkpoint files from HF model repo"""
    print("Downloading checkpoint files...")
    files = [
        "model.safetensors",
        "model_1.safetensors",
        "model_2.safetensors"
    ]
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for filename in files:
        local_path = os.path.join(CHECKPOINT_DIR, filename)
        if not os.path.exists(local_path):
            print(f"  Downloading {filename}...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
                local_dir_use_symlinks=False
            )
            print(f"  ✓ {filename} downloaded")
    print("✓ All checkpoint files ready!")


class ConceptAlignerModel:
    def __init__(self):
        # Download checkpoint first
        download_checkpoint()

        self.checkpoint_path = CHECKPOINT_DIR
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.previous_image = None
        self.previous_prompt = None

        print(f"\n{'='*60}")
        print("Loading ConceptAligner Model")
        print(f"Device: {self.device}")
        print(f"{'='*60}")

        self.setup_models()

    def setup_models(self):
        """Load all models"""
        # Load ConceptAligner
        print("  Loading ConceptAligner...")
        self.model = ConceptAligner().to(self.device).to(self.dtype)
        adapter_path = os.path.join(self.checkpoint_path, "model_1.safetensors")
        adapter_state = load_file(adapter_path)
        self.model.load_state_dict(adapter_state, strict=True)
        print("  ✓ Adapter loaded")

        # Load T5 encoder
        print("  Loading T5 encoder...")
        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
        adapter_path = os.path.join(self.checkpoint_path, "model_2.safetensors")
        adapter_state = load_file(adapter_path)
        # T5 ties its input embeddings to the shared embedding table; if the
        # checkpoint only stores the shared weight, mirror it into
        # embed_tokens so strict loading succeeds.
        if "t5_encoder.shared.weight" in adapter_state and "t5_encoder.encoder.embed_tokens.weight" not in adapter_state:
            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
        self.text_encoder.load_state_dict(adapter_state, strict=True)
        print("  ✓ T5 Adapter loaded")

        # Load VAE
        print("  Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="vae",
            torch_dtype=self.dtype
        ).to(self.device)

        # Load transformer
        print("  Loading transformer...")
        transformer = FluxTransformer2DModel.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="transformer",
            torch_dtype=self.dtype
        )
        target_modules = [
            "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
            "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
            "ff.net.0.proj", "ff.net.2",
            "ff_context.net.0.proj", "ff_context.net.2",
            "proj_mlp", "proj_out",
            "norm.linear", "norm1.linear"
        ]
        transformer_lora_config = LoraConfig(
            r=256,
            lora_alpha=256,
            lora_dropout=0.0,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        # Attach the LoRA adapter so the fine-tuned checkpoint's LoRA keys
        # have matching modules before the strict load below.
        transformer.add_adapter(transformer_lora_config)
        transformer.context_embedder.requires_grad_(True)

        # Load fine-tuned transformer
        transformer_path = os.path.join(self.checkpoint_path, "model.safetensors")
        transformer_state = load_file(transformer_path)
        transformer.load_state_dict(transformer_state, strict=True)
        print("  ✓ Fine-tuned transformer loaded")
        transformer = transformer.to(self.device)

        # Load or download empty pooled clip
        empty_clip_path = "empty_pooled_clip.pt"
        if not os.path.exists(empty_clip_path):
            print("  Downloading empty_pooled_clip.pt...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename="empty_pooled_clip.pt",
                local_dir=".",
                local_dir_use_symlinks=False
            )
        self.empty_pooled_clip = torch.load(empty_clip_path, map_location=self.device).to(self.dtype)

        # Create pipeline
        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="scheduler"
        )
        self.pipe = CustomFluxKontextPipeline(
            scheduler=noise_scheduler,
            aligner=self.model.to(self.device).to(self.dtype),
            transformer=transformer.to(self.device).to(self.dtype),
            vae=vae.to(self.device).to(self.dtype),
            text_embedder=self.text_encoder.to(self.device).to(self.dtype),
        ).to(self.device)

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            reserved = torch.cuda.memory_reserved(0) / 1024**3
            print(f"  ✓ Pipeline ready on {self.device}")
            print(f"  📊 GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
        else:
            print(f"  ✓ Pipeline ready on {self.device}")

    @torch.no_grad()
    def generate_image(
        self,
        prompt,
        threshold=0.0,
        topk=0,
        height=512,
        width=512,
        guidance_scale=3.5,
        true_cf_scale=1.0,
        num_inference_steps=20,
        seed=1995
    ):
        """Generate image and return previous + current for comparison"""
        if not prompt.strip():
            return self.previous_image, None, self.previous_prompt or ""

        try:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            current_image = self.pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                true_cfg_scale=true_cf_scale,
                max_sequence_length=512,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                generator=generator,
            ).images[0]

            # Rotate history: the image we just made becomes "previous" on
            # the next call, so the UI can show a before/after pair.
            prev_image = self.previous_image
            prev_prompt = self.previous_prompt or "No previous generation"
            self.previous_image = current_image
            self.previous_prompt = prompt
            return prev_image, current_image, prev_prompt
        except Exception as e:
            import traceback
            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return self.previous_image, None, self.previous_prompt or ""

    def reset_history(self):
        """Clear generation history"""
        self.previous_image = None
        self.previous_prompt = None
        return None, None, "No previous generation"


# Initialize model
print("Initializing ConceptAligner model...")
model = ConceptAlignerModel()
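

# -----------------------------------------------------------------------------
# Minimal Gradio wiring (a sketch, not the app's actual UI). The component
# names, labels, and ranges below are illustrative assumptions; the only
# contract relied on is generate_image() returning
# (previous_image, current_image, previous_prompt) and reset_history()
# returning (None, None, "No previous generation").
# -----------------------------------------------------------------------------
with gr.Blocks(title="ConceptAligner") as demo:
    prompt_box = gr.Textbox(label="Prompt")
    with gr.Row():
        height_sl = gr.Slider(256, 1024, value=512, step=64, label="Height")
        width_sl = gr.Slider(256, 1024, value=512, step=64, label="Width")
    with gr.Row():
        guidance_sl = gr.Slider(0.0, 10.0, value=3.5, step=0.1, label="Guidance scale")
        cfg_sl = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="True CFG scale")
        steps_sl = gr.Slider(1, 50, value=20, step=1, label="Steps")
        seed_num = gr.Number(value=1995, precision=0, label="Seed")
    with gr.Row():
        prev_img = gr.Image(label="Previous generation")
        curr_img = gr.Image(label="Current generation")
    prev_prompt_box = gr.Textbox(label="Previous prompt", interactive=False)
    with gr.Row():
        gen_btn = gr.Button("Generate", variant="primary")
        reset_btn = gr.Button("Reset history")

    # threshold/topk are accepted by generate_image but not exposed here; a
    # thin wrapper keeps the click signature aligned with the inputs list.
    def _generate(prompt, height, width, guidance, true_cfg, steps, seed):
        return model.generate_image(
            prompt,
            height=int(height),
            width=int(width),
            guidance_scale=guidance,
            true_cf_scale=true_cfg,
            num_inference_steps=int(steps),
            seed=seed,
        )

    gen_btn.click(
        _generate,
        inputs=[prompt_box, height_sl, width_sl, guidance_sl, cfg_sl, steps_sl, seed_num],
        outputs=[prev_img, curr_img, prev_prompt_box],
    )
    reset_btn.click(
        model.reset_history,
        inputs=None,
        outputs=[prev_img, curr_img, prev_prompt_box],
    )

if __name__ == "__main__":
    demo.launch()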