# ConceptAligner / app.py
# Uploaded by: Shaoan
# Commit 642f8f3 (verified): "Upload folder using huggingface_hub"
# Size: 7.79 kB
import torch
import os
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from aligner import ConceptAligner
from text_encoder import LoraT5Embedder
from pipeline import CustomFluxKontextPipeline
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
from peft import LoraConfig
import gradio as gr
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub model repository that hosts the ConceptAligner weights.
MODEL_REPO: str = "Shaoan/ConceptAligner-Weights"
# Local directory the checkpoint shards are downloaded into and loaded from.
CHECKPOINT_DIR: str = "./checkpoint"
def download_checkpoint(repo_id=None, checkpoint_dir=None, files=None):
    """Download the ConceptAligner checkpoint files from the Hugging Face Hub.

    A file is fetched only when it is not already present in
    ``checkpoint_dir``, so calling this again (e.g. after a restart) is cheap.

    Args:
        repo_id: Hub model repo to download from. Defaults to ``MODEL_REPO``.
        checkpoint_dir: Local directory to write the files into (created if
            missing). Defaults to ``CHECKPOINT_DIR``.
        files: Filenames to fetch from the repo. Defaults to
            ``model.safetensors``, ``model_1.safetensors`` and
            ``model_2.safetensors``.
    """
    # Resolve defaults at call time so the module-level constants are only
    # consulted when the caller does not override them.
    if repo_id is None:
        repo_id = MODEL_REPO
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    if files is None:
        files = ("model.safetensors", "model_1.safetensors", "model_2.safetensors")

    print("Downloading checkpoint files...")
    os.makedirs(checkpoint_dir, exist_ok=True)
    for filename in files:
        if os.path.exists(os.path.join(checkpoint_dir, filename)):
            continue
        print(f" Downloading {filename}...")
        # `local_dir_use_symlinks` is deprecated and ignored by current
        # huggingface_hub releases (files under `local_dir` are real copies),
        # so it is no longer passed.
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
        print(f" ✓ {filename} downloaded")
    print("✓ All checkpoint files ready!")
class ConceptAlignerModel:
    """Fine-tuned ConceptAligner + FLUX.1-dev text-to-image model.

    Construction downloads any missing checkpoint shards, loads every
    sub-model (aligner, LoRA T5 text encoder, VAE, LoRA FLUX transformer) and
    assembles them into ``self.pipe``. The instance also remembers the most
    recent successful generation so callers can show previous vs. current.

    NOTE(review): the generation history is stored on the instance, so every
    caller sharing one instance shares (and overwrites) the same history.
    """

    # Base FLUX repository providing the VAE, transformer and scheduler.
    _FLUX_REPO = 'black-forest-labs/FLUX.1-dev'

    # LoRA layout applied to the transformer. It must match the layout the
    # fine-tuned checkpoint was saved with, otherwise the strict load fails.
    _LORA_TARGET_MODULES = (
        "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
        "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
        "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
        "proj_mlp", "proj_out", "norm.linear", "norm1.linear",
    )

    def __init__(self):
        # Make sure every checkpoint shard is on disk before building models.
        download_checkpoint()
        self.checkpoint_path = CHECKPOINT_DIR
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # bfloat16 on GPU, float32 on CPU.
        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.previous_image = None
        self.previous_prompt = None
        print(f"\n{'='*60}")
        print("Loading ConceptAligner Model")
        print(f"Device: {self.device}")
        print(f"{'='*60}")
        self.setup_models()

    def setup_models(self):
        """Load all sub-models and assemble the generation pipeline.

        Sets ``self.model`` (the aligner), ``self.text_encoder``,
        ``self.empty_pooled_clip`` and ``self.pipe``. Every module is placed on
        ``self.device`` in ``self.dtype`` and switched to eval mode.
        """
        self.model = self._load_aligner()
        self.text_encoder = self._load_text_encoder()
        vae = self._load_vae()
        transformer = self._load_transformer()
        self.empty_pooled_clip = self._load_empty_pooled_clip()

        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self._FLUX_REPO, subfolder="scheduler"
        )
        # All components are already on the target device/dtype, so they are
        # passed in as-is rather than being moved a second time.
        self.pipe = CustomFluxKontextPipeline(
            scheduler=noise_scheduler,
            aligner=self.model,
            transformer=transformer,
            vae=vae,
            text_embedder=self.text_encoder,
        ).to(self.device)
        self._report_memory()

    def _load_aligner(self):
        """Build ConceptAligner and load its weights from ``model_1.safetensors``."""
        print(" Loading ConceptAligner...")
        # One combined move avoids materializing a float32 copy on the device
        # before casting.
        aligner = ConceptAligner().to(self.device, self.dtype)
        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
        aligner.load_state_dict(adapter_state, strict=True)
        print(" ✓ Adapter loaded")
        # A directly constructed nn.Module starts in training mode; switch to
        # eval so any dropout/normalization layers behave as at inference.
        return aligner.eval()

    def _load_text_encoder(self):
        """Build the LoRA T5 embedder and load its weights from ``model_2.safetensors``."""
        print(" Loading T5 encoder...")
        text_encoder = LoraT5Embedder(device=self.device).to(self.device, self.dtype)
        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
        # The checkpoint may contain only the shared (tied) embedding matrix;
        # mirror it to encoder.embed_tokens so the strict load finds that key.
        shared_key = "t5_encoder.shared.weight"
        embed_key = "t5_encoder.encoder.embed_tokens.weight"
        if shared_key in adapter_state and embed_key not in adapter_state:
            adapter_state[embed_key] = adapter_state[shared_key]
        text_encoder.load_state_dict(adapter_state, strict=True)
        print(" ✓ T5 Adapter loaded")
        return text_encoder.eval()

    def _load_vae(self):
        """Load the FLUX.1-dev VAE onto the target device/dtype."""
        print(" Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            self._FLUX_REPO,
            subfolder="vae",
            torch_dtype=self.dtype,
        )
        return vae.to(self.device, self.dtype).eval()

    def _load_transformer(self):
        """Load the FLUX transformer, attach LoRA and load ``model.safetensors``."""
        print(" Loading transformer...")
        transformer = FluxTransformer2DModel.from_pretrained(
            self._FLUX_REPO,
            subfolder="transformer",
            torch_dtype=self.dtype,
        )
        transformer_lora_config = LoraConfig(
            r=256,
            lora_alpha=256,
            lora_dropout=0.0,
            init_lora_weights="gaussian",
            target_modules=list(self._LORA_TARGET_MODULES),
        )
        transformer.add_adapter(transformer_lora_config)
        # NOTE(review): carried over from the training setup; generation runs
        # under torch.no_grad(), so this does not cause gradient tracking there.
        transformer.context_embedder.requires_grad_(True)

        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
        transformer.load_state_dict(transformer_state, strict=True)
        print(" ✓ Fine-tuned transformer loaded")
        # Cast after loading so any injected adapter weights end up in
        # self.dtype too, then disable training-mode behavior.
        return transformer.to(self.device, self.dtype).eval()

    def _load_empty_pooled_clip(self):
        """Load (downloading if needed) the pre-computed ``empty_pooled_clip.pt`` tensor."""
        empty_clip_path = "empty_pooled_clip.pt"
        if not os.path.exists(empty_clip_path):
            print(" Downloading empty_pooled_clip.pt...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename="empty_pooled_clip.pt",
                local_dir=".",
            )
        # The file comes from a remote repo and its result is immediately
        # `.to(dtype)`-cast, so it should be a plain tensor. weights_only=True
        # refuses arbitrary pickled objects.
        clip = torch.load(empty_clip_path, map_location=self.device, weights_only=True)
        return clip.to(self.dtype)

    def _report_memory(self):
        """Print a readiness message, plus CUDA memory usage when on GPU."""
        print(f" ✓ Pipeline ready on {self.device}")
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            reserved = torch.cuda.memory_reserved(0) / 1024**3
            print(f" 📊 GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

    @torch.no_grad()
    def generate_image(
        self,
        prompt,
        threshold=0.0,
        topk=0,
        height=512,
        width=512,
        guidance_scale=3.5,
        true_cf_scale=1.0,
        num_inference_steps=20,
        seed=1995
    ):
        """Generate an image and return it alongside the previous generation.

        Args:
            prompt: Text prompt. ``None`` or whitespace-only skips generation.
            threshold: Accepted for interface compatibility; not used here.
                NOTE(review): confirm whether it should reach the pipeline.
            topk: Accepted for interface compatibility; not used here.
                NOTE(review): confirm whether it should reach the pipeline.
            height, width: Output size; cast to int in case the UI sends floats.
            guidance_scale: Passed to the pipeline as ``guidance_scale``.
            true_cf_scale: Passed to the pipeline as ``true_cfg_scale``.
            num_inference_steps: Denoising steps; cast to int.
            seed: Seed for a ``torch.Generator`` on ``self.device``; cast to int.

        Returns:
            ``(previous_image, current_image, previous_prompt)``. When
            generation is skipped or fails, ``current_image`` is ``None`` and
            the stored history is left unchanged.
        """
        if not prompt or not prompt.strip():
            return self.previous_image, None, self.previous_prompt or ""
        try:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
            current_image = self.pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                true_cfg_scale=true_cf_scale,
                max_sequence_length=512,
                num_inference_steps=int(num_inference_steps),
                height=int(height),
                width=int(width),
                generator=generator,
            ).images[0]
        except Exception as e:
            # UI boundary: log the full traceback and keep the app responsive.
            import traceback
            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return self.previous_image, None, self.previous_prompt or ""

        prev_image = self.previous_image
        prev_prompt = self.previous_prompt or "No previous generation"
        self.previous_image = current_image
        self.previous_prompt = prompt
        return prev_image, current_image, prev_prompt

    def reset_history(self):
        """Forget the stored previous generation and return cleared UI values."""
        self.previous_image = None
        self.previous_prompt = None
        return None, None, "No previous generation"
# Build the single shared model instance at import time: this downloads the
# checkpoint files and loads every sub-model before any code below runs.
print("Initializing ConceptAligner model...")
model: "ConceptAlignerModel" = ConceptAlignerModel()