Shaoan commited on
Commit
7ce8df1
Β·
verified Β·
1 Parent(s): 4aacf24

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +48 -35
app.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- ConceptAligner Hugging Face Demo - Optimized for storage
 
3
  """
4
 
5
  import torch
@@ -23,11 +24,6 @@ if HF_TOKEN:
23
  MODEL_REPO = "Shaoan/ConceptAligner-Weights"
24
  CHECKPOINT_DIR = "./checkpoint"
25
 
26
- # Use HF cache directory to avoid duplication
27
- os.environ["HF_HOME"] = "/data/.huggingface"
28
- os.environ["TRANSFORMERS_CACHE"] = "/data/.huggingface/hub"
29
- os.environ["HF_HUB_CACHE"] = "/data/.huggingface/hub"
30
-
31
  EXAMPLE_PROMPTS = [
32
  ["""In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""]
33
  ]
@@ -47,11 +43,11 @@ def download_checkpoint():
47
  repo_id=MODEL_REPO,
48
  filename=filename,
49
  local_dir=CHECKPOINT_DIR,
50
- local_dir_use_symlinks=False,
51
  token=HF_TOKEN
52
  )
 
53
 
54
- print("βœ“ Checkpoint files ready!")
55
 
56
  class ConceptAlignerModel:
57
  def __init__(self):
@@ -71,14 +67,14 @@ class ConceptAlignerModel:
71
  print(f"Loading models on {self.device}...")
72
 
73
  # Load ConceptAligner
74
- print(" Loading ConceptAligner adapter...")
75
  self.model = ConceptAligner().to(self.device).to(self.dtype)
76
  adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
77
  self.model.load_state_dict(adapter_state, strict=True)
78
  print(" βœ“ ConceptAligner loaded")
79
 
80
- # Load T5 encoder
81
- print(" Loading T5 encoder adapter...")
82
  self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
83
  adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
84
  if "t5_encoder.shared.weight" in adapter_state:
@@ -86,30 +82,33 @@ class ConceptAlignerModel:
86
  self.text_encoder.load_state_dict(adapter_state, strict=True)
87
  print(" βœ“ T5 encoder loaded")
88
 
89
- # Load VAE (will use shared cache)
90
  print(" Loading VAE from FLUX.1-dev...")
91
  vae = AutoencoderKL.from_pretrained(
92
  'black-forest-labs/FLUX.1-dev',
93
  subfolder="vae",
94
  torch_dtype=self.dtype,
95
- token=HF_TOKEN,
96
- cache_dir="/data/.huggingface/hub",
97
- low_cpu_mem_usage=True
98
  ).to(self.device)
99
- print(" βœ“ VAE loaded")
 
 
 
100
 
101
- # Load transformer (will use shared cache)
102
- print(" Loading transformer from FLUX.1-dev...")
103
- transformer = FluxTransformer2DModel.from_pretrained(
104
  'black-forest-labs/FLUX.1-dev',
105
  subfolder="transformer",
106
- torch_dtype=self.dtype,
107
- token=HF_TOKEN,
108
- cache_dir="/data/.huggingface/hub",
109
- low_cpu_mem_usage=True
110
  )
111
 
112
- print(" Adding LoRA adapters to transformer...")
 
 
 
 
 
113
  transformer_lora_config = LoraConfig(
114
  r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
115
  target_modules=[
@@ -122,27 +121,32 @@ class ConceptAlignerModel:
122
  transformer.add_adapter(transformer_lora_config)
123
  transformer.context_embedder.requires_grad_(True)
124
 
125
- print(" Loading fine-tuned transformer weights...")
 
126
  transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
127
  transformer.load_state_dict(transformer_state, strict=True)
128
- transformer = transformer.to(self.device)
129
- print(" βœ“ Transformer loaded")
130
 
131
  # Load empty pooled clip
132
  self.empty_pooled_clip = torch.load(
133
  os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
134
- map_location=self.device
 
135
  ).to(self.dtype)
 
136
 
137
- # Create pipeline
138
- print(" Creating pipeline...")
139
  noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
140
  'black-forest-labs/FLUX.1-dev',
141
  subfolder="scheduler",
142
- token=HF_TOKEN,
143
- cache_dir="/data/.huggingface/hub"
144
  )
 
145
 
 
 
146
  self.pipe = CustomFluxKontextPipeline(
147
  scheduler=noise_scheduler,
148
  aligner=self.model,
@@ -151,7 +155,11 @@ class ConceptAlignerModel:
151
  text_embedder=self.text_encoder,
152
  ).to(self.device)
153
 
154
- print("βœ“ All models loaded successfully!")
 
 
 
 
155
 
156
  # Print memory usage
157
  if torch.cuda.is_available():
@@ -193,13 +201,18 @@ class ConceptAlignerModel:
193
 
194
  # Initialize model
195
  print("="*60)
196
- print("Initializing ConceptAligner...")
197
  print("="*60)
198
  model = ConceptAlignerModel()
199
 
200
  # Create Gradio interface
201
  with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
202
- gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")
 
 
 
 
 
203
 
204
  with gr.Row():
205
  with gr.Column(scale=1):
 
1
  """
2
+ ConceptAligner Hugging Face Demo - Minimal downloads
3
+ Only downloads VAE, uses your fine-tuned weights for everything else
4
  """
5
 
6
  import torch
 
24
  MODEL_REPO = "Shaoan/ConceptAligner-Weights"
25
  CHECKPOINT_DIR = "./checkpoint"
26
 
 
 
 
 
 
27
  EXAMPLE_PROMPTS = [
28
  ["""In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""]
29
  ]
 
43
  repo_id=MODEL_REPO,
44
  filename=filename,
45
  local_dir=CHECKPOINT_DIR,
 
46
  token=HF_TOKEN
47
  )
48
+ print(f" βœ“ {filename} downloaded")
49
 
50
+ print("βœ“ All checkpoint files ready!")
51
 
52
  class ConceptAlignerModel:
53
  def __init__(self):
 
67
  print(f"Loading models on {self.device}...")
68
 
69
  # Load ConceptAligner
70
+ print(" Loading ConceptAligner...")
71
  self.model = ConceptAligner().to(self.device).to(self.dtype)
72
  adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
73
  self.model.load_state_dict(adapter_state, strict=True)
74
  print(" βœ“ ConceptAligner loaded")
75
 
76
+ # Load T5 encoder (your fine-tuned version with full weights)
77
+ print(" Loading fine-tuned T5 encoder...")
78
  self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
79
  adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
80
  if "t5_encoder.shared.weight" in adapter_state:
 
82
  self.text_encoder.load_state_dict(adapter_state, strict=True)
83
  print(" βœ“ T5 encoder loaded")
84
 
85
+ # Only download VAE (small ~330MB) - not fine-tuned
86
  print(" Loading VAE from FLUX.1-dev...")
87
  vae = AutoencoderKL.from_pretrained(
88
  'black-forest-labs/FLUX.1-dev',
89
  subfolder="vae",
90
  torch_dtype=self.dtype,
91
+ token=HF_TOKEN
 
 
92
  ).to(self.device)
93
+ print(" βœ“ VAE loaded (~330MB download)")
94
+
95
+ # Create transformer architecture WITHOUT downloading base weights
96
+ print(" Initializing transformer architecture...")
97
 
98
+ # Get config only (no weights download)
99
+ from diffusers.models.transformers.transformer_flux import FluxTransformerConfig
100
+ config = FluxTransformer2DModel.load_config(
101
  'black-forest-labs/FLUX.1-dev',
102
  subfolder="transformer",
103
+ token=HF_TOKEN
 
 
 
104
  )
105
 
106
+ # Initialize empty transformer from config
107
+ transformer = FluxTransformer2DModel.from_config(config)
108
+ print(" βœ“ Transformer architecture initialized")
109
+
110
+ # Add LoRA config (needed for architecture)
111
+ print(" Adding LoRA adapter config...")
112
  transformer_lora_config = LoraConfig(
113
  r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
114
  target_modules=[
 
121
  transformer.add_adapter(transformer_lora_config)
122
  transformer.context_embedder.requires_grad_(True)
123
 
124
+ # Load YOUR FULL fine-tuned transformer weights (no base model needed!)
125
+ print(" Loading YOUR fine-tuned transformer weights...")
126
  transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
127
  transformer.load_state_dict(transformer_state, strict=True)
128
+ transformer = transformer.to(self.device).to(self.dtype)
129
+ print(" βœ“ Fine-tuned transformer loaded (~26GB from your checkpoint)")
130
 
131
  # Load empty pooled clip
132
  self.empty_pooled_clip = torch.load(
133
  os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
134
+ map_location=self.device,
135
+ weights_only=True
136
  ).to(self.dtype)
137
+ print(" βœ“ Empty pooled clip loaded")
138
 
139
+ # Create scheduler (just config, no weights)
140
+ print(" Loading scheduler config...")
141
  noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
142
  'black-forest-labs/FLUX.1-dev',
143
  subfolder="scheduler",
144
+ token=HF_TOKEN
 
145
  )
146
+ print(" βœ“ Scheduler loaded")
147
 
148
+ # Create pipeline
149
+ print(" Creating pipeline...")
150
  self.pipe = CustomFluxKontextPipeline(
151
  scheduler=noise_scheduler,
152
  aligner=self.model,
 
155
  text_embedder=self.text_encoder,
156
  ).to(self.device)
157
 
158
+ print("="*60)
159
+ print("βœ“ ALL MODELS LOADED SUCCESSFULLY!")
160
+ print("="*60)
161
+ print(f"Total downloads: ~330MB VAE + ~26GB your checkpoints")
162
+ print(f"Saved: ~24GB by not downloading base FLUX transformer!")
163
 
164
  # Print memory usage
165
  if torch.cuda.is_available():
 
201
 
202
  # Initialize model
203
  print("="*60)
204
+ print("πŸš€ Initializing ConceptAligner Demo")
205
  print("="*60)
206
  model = ConceptAlignerModel()
207
 
208
  # Create Gradio interface
209
  with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
210
+ gr.Markdown("""
211
+ # 🎨 ConceptAligner Demo
212
+
213
+ Generate images with fine-tuned concept alignment using FLUX!
214
+ This demo uses fully fine-tuned weights - no base model downloads needed.
215
+ """)
216
 
217
  with gr.Row():
218
  with gr.Column(scale=1):