Shaoan committed
Commit 4aacf24 · verified · 1 Parent(s): 860b278

Upload folder using huggingface_hub

Files changed (1): app.py (+42, -12)
app.py CHANGED

@@ -1,6 +1,5 @@
 """
-ConceptAligner Hugging Face Demo
-Downloads weights from model repo at startup
+ConceptAligner Hugging Face Demo - Optimized for storage
 """
 
 import torch
@@ -19,13 +18,16 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
     print("✓ Logged in to Hugging Face")
-else:
-    print("⚠️ Warning: No HF_TOKEN found in environment")
 
 # Configuration
 MODEL_REPO = "Shaoan/ConceptAligner-Weights"
 CHECKPOINT_DIR = "./checkpoint"
 
+# Use HF cache directory to avoid duplication
+os.environ["HF_HOME"] = "/data/.huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/data/.huggingface/hub"
+os.environ["HF_HUB_CACHE"] = "/data/.huggingface/hub"
+
 EXAMPLE_PROMPTS = [
     ["""In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""]
 ]
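Note: huggingface_hub resolves its default cache path when the module is imported, so env vars set this far down only affect code that reads the environment later; the explicit cache_dir arguments in the hunks below cover the diffusers calls either way. Recent transformers releases also deprecate TRANSFORMERS_CACHE in favor of HF_HOME. A minimal sketch of the safer ordering, using the paths from this commit:

```python
# Set cache env vars before any HF imports, since huggingface_hub computes
# its default cache constant at import time.
import os

os.environ["HF_HOME"] = "/data/.huggingface"           # parent cache dir
os.environ["HF_HUB_CACHE"] = "/data/.huggingface/hub"  # hub snapshot cache

from huggingface_hub import constants  # imported after the env is set

print(constants.HF_HUB_CACHE)  # -> /data/.huggingface/hub
```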
@@ -49,7 +51,7 @@ def download_checkpoint():
         token=HF_TOKEN
     )
 
-    print("✓ All files ready!")
+    print("✓ Checkpoint files ready!")
 
 class ConceptAlignerModel:
     def __init__(self):
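Note: only the tail of download_checkpoint() is visible in this hunk, so the sketch below is an assumption about its shape; per-file hf_hub_download calls would fit the visible `token=HF_TOKEN)` closing equally well.

```python
# Hedged sketch of download_checkpoint(); snapshot_download is an assumption.
from huggingface_hub import snapshot_download

def download_checkpoint():
    snapshot_download(
        repo_id=MODEL_REPO,        # "Shaoan/ConceptAligner-Weights"
        local_dir=CHECKPOINT_DIR,  # "./checkpoint"
        token=HF_TOKEN,            # needed if the weights repo is private
    )
    print("✓ Checkpoint files ready!")
```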
@@ -69,33 +71,45 @@ class ConceptAlignerModel:
         print(f"Loading models on {self.device}...")
 
         # Load ConceptAligner
+        print("  Loading ConceptAligner adapter...")
         self.model = ConceptAligner().to(self.device).to(self.dtype)
         adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
         self.model.load_state_dict(adapter_state, strict=True)
+        print("  ✓ ConceptAligner loaded")
 
         # Load T5 encoder
+        print("  Loading T5 encoder adapter...")
         self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
         adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
         if "t5_encoder.shared.weight" in adapter_state:
             adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
         self.text_encoder.load_state_dict(adapter_state, strict=True)
+        print("  ✓ T5 encoder loaded")
 
-        # Load VAE with token
+        # Load VAE (will use shared cache)
+        print("  Loading VAE from FLUX.1-dev...")
         vae = AutoencoderKL.from_pretrained(
             'black-forest-labs/FLUX.1-dev',
             subfolder="vae",
             torch_dtype=self.dtype,
-            token=HF_TOKEN
+            token=HF_TOKEN,
+            cache_dir="/data/.huggingface/hub",
+            low_cpu_mem_usage=True
         ).to(self.device)
+        print("  ✓ VAE loaded")
 
-        # Load transformer with token
+        # Load transformer (will use shared cache)
+        print("  Loading transformer from FLUX.1-dev...")
         transformer = FluxTransformer2DModel.from_pretrained(
             'black-forest-labs/FLUX.1-dev',
             subfolder="transformer",
             torch_dtype=self.dtype,
-            token=HF_TOKEN
+            token=HF_TOKEN,
+            cache_dir="/data/.huggingface/hub",
+            low_cpu_mem_usage=True
         )
 
+        print("  Adding LoRA adapters to transformer...")
         transformer_lora_config = LoraConfig(
             r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
             target_modules=[
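Note on the key remap before the T5 strict load: T5 ties its shared token embedding to the encoder's embed_tokens, so both state-dict names refer to one tensor; a checkpoint that saved only the `shared` key can be copied to the `embed_tokens` key before loading with strict=True. A quick check of the tying (model id is illustrative):

```python
# T5's shared embedding and encoder.embed_tokens are the same tied tensor.
from transformers import T5EncoderModel

t5 = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
assert t5.shared.weight is t5.encoder.embed_tokens.weight  # one tied tensor
```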
@@ -108,9 +122,11 @@ class ConceptAlignerModel:
         transformer.add_adapter(transformer_lora_config)
         transformer.context_embedder.requires_grad_(True)
 
+        print("  Loading fine-tuned transformer weights...")
         transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
         transformer.load_state_dict(transformer_state, strict=True)
         transformer = transformer.to(self.device)
+        print("  ✓ Transformer loaded")
 
         # Load empty pooled clip
         self.empty_pooled_clip = torch.load(
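Note: after add_adapter(), the transformer's state dict gains lora_A/lora_B parameters, and strict=True requires the checkpoint to supply every one of them. A sanity-check sketch before loading (`transformer` is the model from the hunk above; the path is illustrative):

```python
# Compare checkpoint keys against the adapter-augmented state dict.
from safetensors.torch import load_file

state = load_file("checkpoint/model.safetensors")
missing = set(transformer.state_dict()) - set(state)
extra = set(state) - set(transformer.state_dict())
print(f"missing: {len(missing)}, unexpected: {len(extra)}")  # both 0 if strict load will pass
```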
@@ -119,8 +135,12 @@ class ConceptAlignerModel:
         ).to(self.dtype)
 
         # Create pipeline
+        print("  Creating pipeline...")
         noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-            'black-forest-labs/FLUX.1-dev', subfolder="scheduler", token=HF_TOKEN
+            'black-forest-labs/FLUX.1-dev',
+            subfolder="scheduler",
+            token=HF_TOKEN,
+            cache_dir="/data/.huggingface/hub"
         )
 
         self.pipe = CustomFluxKontextPipeline(
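Note: the same token/cache_dir pair is now repeated across the VAE, transformer, and scheduler calls. A small helper could keep them in one place; the helper name and constant below are hypothetical, not part of the commit:

```python
# Hypothetical helper to deduplicate the repeated from_pretrained kwargs.
FLUX_REPO = "black-forest-labs/FLUX.1-dev"
HUB_KWARGS = dict(token=HF_TOKEN, cache_dir="/data/.huggingface/hub")

def load_flux_component(cls, subfolder, **extra):
    return cls.from_pretrained(FLUX_REPO, subfolder=subfolder, **HUB_KWARGS, **extra)

# e.g. noise_scheduler = load_flux_component(FlowMatchEulerDiscreteScheduler, "scheduler")
```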
@@ -131,7 +151,13 @@ class ConceptAlignerModel:
             text_embedder=self.text_encoder,
         ).to(self.device)
 
-        print("✓ Model loaded!")
+        print("✓ All models loaded successfully!")
+
+        # Print memory usage
+        if torch.cuda.is_available():
+            allocated = torch.cuda.memory_allocated(0) / 1024**3
+            reserved = torch.cuda.memory_reserved(0) / 1024**3
+            print(f"📊 GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
 
     @torch.no_grad()
     def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
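Note: in the memory report, "allocated" counts memory occupied by live tensors while "reserved" counts what PyTorch's caching allocator holds from the driver, so reserved is always at least allocated:

```python
# allocated = live tensor memory; reserved = allocator's pool from the driver.
import torch

if torch.cuda.is_available():
    assert torch.cuda.memory_reserved(0) >= torch.cuda.memory_allocated(0)
    print(torch.cuda.memory_summary(0, abbreviated=True))  # fuller breakdown
```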
@@ -155,7 +181,9 @@ class ConceptAlignerModel:
 
             return prev_image, current_image, prev_prompt
         except Exception as e:
-            print(f"Error: {e}")
+            import traceback
+            print(f"❌ Generation error: {e}")
+            print(traceback.format_exc())
             return self.previous_image, None, self.previous_prompt or ""
 
     def reset_history(self):
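Note: printing traceback.format_exc() works; an equivalent pattern avoids the in-handler import by letting the logging module capture the same stack trace (logger name is illustrative):

```python
# logging.exception records the message plus the full traceback.
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("conceptaligner")

try:
    raise RuntimeError("demo failure")
except Exception:
    log.exception("Generation failed")
```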
@@ -164,7 +192,9 @@ class ConceptAlignerModel:
         return None, None, "No previous generation"
 
 # Initialize model
+print("="*60)
 print("Initializing ConceptAligner...")
+print("="*60)
 model = ConceptAlignerModel()
 
 # Create Gradio interface
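Note: a minimal smoke test of the loaded model outside Gradio. generate_image's parameter list is truncated in the diff after `width`, so only the visible arguments appear here; the PIL.Image return type is an assumption, as is typical for Gradio image outputs.

```python
# Sketch: one generation with the example prompt; returns
# (previous_image, current_image, previous_prompt) per the hunk above.
m = ConceptAlignerModel()
prev_img, img, prev_prompt = m.generate_image(
    EXAMPLE_PROMPTS[0][0], threshold=0.0, topk=0, height=512, width=512
)
if img is not None:
    img.save("duck.png")  # assumes a PIL.Image return
```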