import torch
import os
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig
import gradio as gr
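
# NOTE (assumption): on a ZeroGPU Space the GPU-bound entry point is normally
# decorated with @spaces.GPU from the `spaces` package; that import and
# decorator are presumed to live in the UI code not shown in this snippet.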

# Configuration
MODEL_REPO = "Shaoan/ConceptAligner-Weights"  # Your model repo
CHECKPOINT_DIR = "./checkpoint"

def download_checkpoint():
    """Download checkpoint files from the HF model repo."""
    print("Downloading checkpoint files...")
    files = [
        "model.safetensors",
        "model_1.safetensors",
        "model_2.safetensors",
    ]
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for filename in files:
        local_path = os.path.join(CHECKPOINT_DIR, filename)
        if not os.path.exists(local_path):
            print(f"  Downloading {filename}...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename=filename,
                local_dir=CHECKPOINT_DIR,
                # Deprecated (and ignored) in recent huggingface_hub releases;
                # kept for compatibility with older versions.
                local_dir_use_symlinks=False,
            )
            print(f"  ✓ {filename} downloaded")
    print("✓ All checkpoint files ready!")

class ConceptAlignerModel:
    def __init__(self):
        # Download checkpoint first
        download_checkpoint()
        self.checkpoint_path = CHECKPOINT_DIR
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.previous_image = None
        self.previous_prompt = None
        print(f"\n{'='*60}")
        print("Loading ConceptAligner Model")
        print(f"Device: {self.device}")
        print(f"{'='*60}")
        self.setup_models()
    def setup_models(self):
        """Load all models."""
        # Load the ConceptAligner adapter
        print("  Loading ConceptAligner...")
        self.model = ConceptAligner().to(self.device).to(self.dtype)
        adapter_path = os.path.join(self.checkpoint_path, "model_1.safetensors")
        adapter_state = load_file(adapter_path)
        self.model.load_state_dict(adapter_state, strict=True)
        print("  ✓ Adapter loaded")

        # Load the T5 encoder with its LoRA adapter
        print("  Loading T5 encoder...")
        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
        adapter_path = os.path.join(self.checkpoint_path, "model_2.safetensors")
        adapter_state = load_file(adapter_path)
        # HF T5 ties encoder.embed_tokens to the shared embedding matrix, so a
        # checkpoint saved with tied weights may omit the embed_tokens key;
        # alias it back so strict loading succeeds.
        if "t5_encoder.shared.weight" in adapter_state and "t5_encoder.encoder.embed_tokens.weight" not in adapter_state:
            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
        self.text_encoder.load_state_dict(adapter_state, strict=True)
        print("  ✓ T5 adapter loaded")

        # Load VAE
        print("  Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="vae",
            torch_dtype=self.dtype,
        ).to(self.device)

        # Load transformer
        print("  Loading transformer...")
        transformer = FluxTransformer2DModel.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            subfolder="transformer",
            torch_dtype=self.dtype,
        )
        target_modules = [
            "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
            "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
            "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
            "proj_mlp", "proj_out", "norm.linear", "norm1.linear",
        ]
        transformer_lora_config = LoraConfig(
            r=256,
            lora_alpha=256,
            lora_dropout=0.0,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
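        # NOTE: r, lora_alpha, and target_modules must mirror the training-time
        # LoRA configuration exactly; otherwise the strict=True state-dict load
        # below fails on missing/unexpected adapter keys.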
        transformer.add_adapter(transformer_lora_config)
        # context_embedder was presumably fine-tuned in full (it is not in
        # target_modules); the requires_grad_ call is a leftover from training
        # and has no effect at inference.
        transformer.context_embedder.requires_grad_(True)

        # Load fine-tuned transformer weights
        transformer_path = os.path.join(self.checkpoint_path, "model.safetensors")
        transformer_state = load_file(transformer_path)
        transformer.load_state_dict(transformer_state, strict=True)
        print("  ✓ Fine-tuned transformer loaded")
        transformer = transformer.to(self.device)

        # Load or download the empty pooled CLIP embedding
        empty_clip_path = "empty_pooled_clip.pt"
        if not os.path.exists(empty_clip_path):
            print("  Downloading empty_pooled_clip.pt...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename="empty_pooled_clip.pt",
                local_dir=".",
                local_dir_use_symlinks=False,
            )
        self.empty_pooled_clip = torch.load(empty_clip_path, map_location=self.device).to(self.dtype)
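        # The pooled CLIP embedding for the empty prompt is precomputed and
        # shipped with the weights repo, which presumably lets the pipeline
        # avoid loading the CLIP text encoder at runtime.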

        # Create pipeline
        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
        )
        self.pipe = CustomFluxKontextPipeline(
            scheduler=noise_scheduler,
            aligner=self.model.to(self.device).to(self.dtype),
            transformer=transformer.to(self.device).to(self.dtype),
            vae=vae.to(self.device).to(self.dtype),
            text_embedder=self.text_encoder.to(self.device).to(self.dtype),
        ).to(self.device)

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            reserved = torch.cuda.memory_reserved(0) / 1024**3
            print(f"  ✓ Pipeline ready on {self.device}")
            print(f"  GPU memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
        else:
            print(f"  ✓ Pipeline ready on {self.device}")
    def generate_image(
        self,
        prompt,
        threshold=0.0,
        topk=0,
        height=512,
        width=512,
        guidance_scale=3.5,
        true_cf_scale=1.0,
        num_inference_steps=20,
        seed=1995,
    ):
        """Generate an image and return previous + current for comparison."""
        if not prompt.strip():
            return self.previous_image, None, self.previous_prompt or ""
        try:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            # threshold and topk are accepted for API compatibility but are not
            # forwarded to the pipeline in this snippet.
            current_image = self.pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                true_cfg_scale=true_cf_scale,
                max_sequence_length=512,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                generator=generator,
            ).images[0]
            prev_image = self.previous_image
            prev_prompt = self.previous_prompt or "No previous generation"
            self.previous_image = current_image
            self.previous_prompt = prompt
            return prev_image, current_image, prev_prompt
        except Exception as e:
            import traceback
            error_msg = f"✗ Error: {e}\n{traceback.format_exc()}"
            print(error_msg)
            return self.previous_image, None, self.previous_prompt or ""

    def reset_history(self):
        """Clear generation history."""
        self.previous_image = None
        self.previous_prompt = None
        return None, None, "No previous generation"

# Initialize model
print("Initializing ConceptAligner model...")
model = ConceptAlignerModel()
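
# --- Hedged sketch (not part of the original file): a minimal Gradio Blocks UI
# wiring generate_image/reset_history. The real Space presumably defines its
# own interface with more controls; widget names and layout here are
# assumptions for illustration only.
with gr.Blocks(title="ConceptAligner") as demo:
    gr.Markdown("## ConceptAligner: previous vs. current generation")
    prompt = gr.Textbox(label="Prompt", placeholder="Describe the image to generate")
    with gr.Row():
        seed = gr.Number(value=1995, precision=0, label="Seed")
        steps = gr.Slider(1, 50, value=20, step=1, label="Inference steps")
        guidance = gr.Slider(0.0, 10.0, value=3.5, step=0.1, label="Guidance scale")
    with gr.Row():
        prev_img = gr.Image(label="Previous", interactive=False)
        curr_img = gr.Image(label="Current", interactive=False)
    prev_prompt = gr.Textbox(label="Previous prompt", interactive=False)
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        reset_btn = gr.Button("Reset history")

    def _generate(p, s, n, g):
        # Pass defaults for the parameters this sketch does not expose
        # (threshold, topk, height, width, true_cf_scale).
        return model.generate_image(
            p, num_inference_steps=int(n), guidance_scale=g, seed=int(s)
        )

    generate_btn.click(
        fn=_generate,
        inputs=[prompt, seed, steps, guidance],
        outputs=[prev_img, curr_img, prev_prompt],
    )
    reset_btn.click(
        fn=model.reset_history,
        inputs=[],
        outputs=[prev_img, curr_img, prev_prompt],
    )

demo.launch()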