Shaoan committed
Commit 36f6af4 Β· verified Β· Parent: 642f8f3

Upload folder using huggingface_hub

Files changed (1): app.py (+187 βˆ’190)
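
The commit message indicates the folder was pushed with huggingface_hub. For reference, a minimal sketch of such a push; the target repo id and local path below are illustrative assumptions, not taken from this commit:

    # Sketch: push a local folder to a Hub repo with huggingface_hub.
    # "Shaoan/ConceptAligner" and folder_path="." are hypothetical placeholders.
    from huggingface_hub import HfApi

    api = HfApi()  # picks up the token from `huggingface-cli login`
    api.upload_folder(
        folder_path=".",                  # local folder containing app.py
        repo_id="Shaoan/ConceptAligner",  # hypothetical target Space
        repo_type="space",
        commit_message="Upload folder using huggingface_hub",
    )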
app.py CHANGED
@@ -1,4 +1,10 @@
 import torch
 import os
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
@@ -7,204 +13,195 @@ from text_encoder import LoraT5Embedder
 from pipeline import CustomFluxKontextPipeline
 from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
 from peft import LoraConfig
-import gradio as gr

 # Configuration
-MODEL_REPO = "Shaoan/ConceptAligner-Weights"  # Your model repo
 CHECKPOINT_DIR = "./checkpoint"

 def download_checkpoint():
-    """Download checkpoint files from HF model repo"""
-    print("Downloading checkpoint files...")
-
-    files = [
-        "model.safetensors",
-        "model_1.safetensors",
-        "model_2.safetensors"
-    ]
-
-    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
-
-    for filename in files:
-        local_path = os.path.join(CHECKPOINT_DIR, filename)
-        if not os.path.exists(local_path):
-            print(f" Downloading {filename}...")
-            hf_hub_download(
-                repo_id=MODEL_REPO,
-                filename=filename,
-                local_dir=CHECKPOINT_DIR,
-                local_dir_use_symlinks=False
-            )
-            print(f" βœ“ {filename} downloaded")
-
-    print("βœ“ All checkpoint files ready!")

 class ConceptAlignerModel:
-    def __init__(self):
-        # Download checkpoint first
-        download_checkpoint()
-
-        self.checkpoint_path = CHECKPOINT_DIR
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
-        self.previous_image = None
-        self.previous_prompt = None
-
-        print(f"\n{'='*60}")
-        print(f"Loading ConceptAligner Model")
-        print(f"Device: {self.device}")
-        print(f"{'='*60}")
-
-        self.setup_models()
-
-    def setup_models(self):
-        """Load all models"""
-        # Load ConceptAligner
-        print(f" Loading ConceptAligner...")
-        self.model = ConceptAligner().to(self.device).to(self.dtype)
-        adapter_path = os.path.join(self.checkpoint_path, "model_1.safetensors")
-        adapter_state = load_file(adapter_path)
-        self.model.load_state_dict(adapter_state, strict=True)
-        print(f" βœ“ Adapter loaded")
-
-        # Load T5 encoder
-        print(f" Loading T5 encoder...")
-        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
-        adapter_path = os.path.join(self.checkpoint_path, "model_2.safetensors")
-        adapter_state = load_file(adapter_path)
-        if "t5_encoder.shared.weight" in adapter_state and "t5_encoder.encoder.embed_tokens.weight" not in adapter_state:
-            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
-        self.text_encoder.load_state_dict(adapter_state, strict=True)
-        print(f" βœ“ T5 Adapter loaded")
-
-        # Load VAE
-        print(f" Loading VAE...")
-        vae = AutoencoderKL.from_pretrained(
-            'black-forest-labs/FLUX.1-dev',
-            subfolder="vae",
-            torch_dtype=self.dtype
-        ).to(self.device)
-
-        # Load transformer
-        print(f" Loading transformer...")
-        transformer = FluxTransformer2DModel.from_pretrained(
-            'black-forest-labs/FLUX.1-dev',
-            subfolder="transformer",
-            torch_dtype=self.dtype
-        )
-
-        target_modules = [
-            "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
-            "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
-            "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
-            "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
-        ]
-
-        transformer_lora_config = LoraConfig(
-            r=256,
-            lora_alpha=256,
-            lora_dropout=0.0,
-            init_lora_weights="gaussian",
-            target_modules=target_modules,
-        )
-        transformer.add_adapter(transformer_lora_config)
-        transformer.context_embedder.requires_grad_(True)
-
-        # Load fine-tuned transformer
-        transformer_path = os.path.join(self.checkpoint_path, "model.safetensors")
-        transformer_state = load_file(transformer_path)
-        transformer.load_state_dict(transformer_state, strict=True)
-        print(f" βœ“ Fine-tuned transformer loaded")
-
-        transformer = transformer.to(self.device)
-
-        # Load or download empty pooled clip
-        empty_clip_path = "empty_pooled_clip.pt"
-        if not os.path.exists(empty_clip_path):
-            print(" Downloading empty_pooled_clip.pt...")
-            hf_hub_download(
-                repo_id=MODEL_REPO,
-                filename="empty_pooled_clip.pt",
-                local_dir=".",
-                local_dir_use_symlinks=False
-            )
-
-        self.empty_pooled_clip = torch.load(empty_clip_path, map_location=self.device).to(self.dtype)
-
-        # Create pipeline
-        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-            'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
-        )
-
-        self.pipe = CustomFluxKontextPipeline(
-            scheduler=noise_scheduler,
-            aligner=self.model.to(self.device).to(self.dtype),
-            transformer=transformer.to(self.device).to(self.dtype),
-            vae=vae.to(self.device).to(self.dtype),
-            text_embedder=self.text_encoder.to(self.device).to(self.dtype),
-        ).to(self.device)
-
-        if torch.cuda.is_available():
-            allocated = torch.cuda.memory_allocated(0) / 1024**3
-            reserved = torch.cuda.memory_reserved(0) / 1024**3
-            print(f" βœ“ Pipeline ready on {self.device}")
-            print(f" πŸ“Š GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
-        else:
-            print(f" βœ“ Pipeline ready on {self.device}")
-
-    @torch.no_grad()
-    def generate_image(
-        self,
-        prompt,
-        threshold=0.0,
-        topk=0,
-        height=512,
-        width=512,
-        guidance_scale=3.5,
-        true_cf_scale=1.0,
-        num_inference_steps=20,
-        seed=1995
-    ):
-        """Generate image and return previous + current for comparison"""
-        if not prompt.strip():
-            return self.previous_image, None, self.previous_prompt or ""
-
-        try:
-            generator = torch.Generator(device=self.device).manual_seed(int(seed))
-
-            current_image = self.pipe(
-                prompt=prompt,
-                guidance_scale=guidance_scale,
-                true_cfg_scale=true_cf_scale,
-                max_sequence_length=512,
-                num_inference_steps=num_inference_steps,
-                height=height,
-                width=width,
-                generator=generator,
-            ).images[0]
-
-            prev_image = self.previous_image
-            prev_prompt = self.previous_prompt or "No previous generation"
-
-            self.previous_image = current_image
-            self.previous_prompt = prompt
-
-            return prev_image, current_image, prev_prompt
-
-        except Exception as e:
-            import traceback
-            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
-            print(error_msg)
-            return self.previous_image, None, self.previous_prompt or ""
-
-    def reset_history(self):
-        """Clear generation history"""
-        self.previous_image = None
-        self.previous_prompt = None
-        return None, None, "No previous generation"


 # Initialize model
-print("Initializing ConceptAligner model...")
 model = ConceptAlignerModel()
+"""
+ConceptAligner Hugging Face Demo
+Downloads weights from model repo at startup
+"""
+
 import torch
+import gradio as gr
 import os
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 from pipeline import CustomFluxKontextPipeline
 from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
 from peft import LoraConfig

 # Configuration
+MODEL_REPO = "Shaoan/ConceptAligner-Weights"
 CHECKPOINT_DIR = "./checkpoint"

+EXAMPLE_PROMPTS = [
+    [
+        """In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""]
+]
+
+
 def download_checkpoint():
+    """Download checkpoint files from HF model repo"""
+    print("Downloading checkpoint files...")
+
+    files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"]
+    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
+
+    for filename in files:
+        local_path = os.path.join(CHECKPOINT_DIR, filename)
+        if not os.path.exists(local_path):
+            print(f" Downloading {filename}...")
+            hf_hub_download(
+                repo_id=MODEL_REPO,
+                filename=filename,
+                local_dir=CHECKPOINT_DIR,
+                local_dir_use_symlinks=False
+            )
+
+    print("βœ“ All files ready!")
+

 class ConceptAlignerModel:
+    def __init__(self):
+        download_checkpoint()
+
+        self.checkpoint_path = CHECKPOINT_DIR
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+        self.previous_image = None
+        self.previous_prompt = None
+
+        self.setup_models()
+
+    def setup_models(self):
+        """Load all models"""
+        print(f"Loading models on {self.device}...")
+
+        # Load ConceptAligner
+        self.model = ConceptAligner().to(self.device).to(self.dtype)
+        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
+        self.model.load_state_dict(adapter_state, strict=True)
+
+        # Load T5 encoder
+        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
+        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
+        if "t5_encoder.shared.weight" in adapter_state:
+            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
+        self.text_encoder.load_state_dict(adapter_state, strict=True)
+
+        # Load VAE
+        vae = AutoencoderKL.from_pretrained(
+            'black-forest-labs/FLUX.1-dev', subfolder="vae", torch_dtype=self.dtype
+        ).to(self.device)
+
+        # Load transformer
+        transformer = FluxTransformer2DModel.from_pretrained(
+            'black-forest-labs/FLUX.1-dev', subfolder="transformer", torch_dtype=self.dtype
+        )
+
+        transformer_lora_config = LoraConfig(
+            r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
+            target_modules=[
+                "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
+                "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
+                "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
+                "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
+            ],
+        )
+        transformer.add_adapter(transformer_lora_config)
+        transformer.context_embedder.requires_grad_(True)
+
+        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
+        transformer.load_state_dict(transformer_state, strict=True)
+        transformer = transformer.to(self.device)
+
+        # Load empty pooled clip
+        self.empty_pooled_clip = torch.load(
+            os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
+            map_location=self.device
+        ).to(self.dtype)
+
+        # Create pipeline
+        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
+        )
+
+        self.pipe = CustomFluxKontextPipeline(
+            scheduler=noise_scheduler,
+            aligner=self.model,
+            transformer=transformer,
+            vae=vae,
+            text_embedder=self.text_encoder,
+        ).to(self.device)
+
+        print("βœ“ Model loaded!")
+
+    @torch.no_grad()
+    def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
+                       guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
+        if not prompt.strip():
+            return self.previous_image, None, self.previous_prompt or ""
+
+        try:
+            generator = torch.Generator(device=self.device).manual_seed(int(seed))
+            current_image = self.pipe(
+                prompt=prompt, guidance_scale=guidance_scale, true_cfg_scale=true_cf_scale,
+                max_sequence_length=512, num_inference_steps=num_inference_steps,
+                height=height, width=width, generator=generator,
+            ).images[0]
+
+            prev_image = self.previous_image
+            prev_prompt = self.previous_prompt or "No previous generation"
+
+            self.previous_image = current_image
+            self.previous_prompt = prompt
+
+            return prev_image, current_image, prev_prompt
+        except Exception as e:
+            print(f"Error: {e}")
+            return self.previous_image, None, self.previous_prompt or ""
+
+    def reset_history(self):
+        self.previous_image = None
+        self.previous_prompt = None
+        return None, None, "No previous generation"


 # Initialize model
+print("Initializing ConceptAligner...")
 model = ConceptAlignerModel()
+
+# Create Gradio interface
+with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...")
+
+            with gr.Row():
+                generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3)
+                reset_btn = gr.Button("πŸ”„ Reset", variant="secondary", size="lg", scale=1)
+
+            with gr.Accordion("βš™οΈ Settings", open=True):
+                guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
+                num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps")
+                seed = gr.Number(value=0, label="Seed", precision=0)
+
+            with gr.Accordion("πŸ”¬ Advanced", open=False):
+                true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG")
+                threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold")
+                topk = gr.Slider(0, 300, value=0, step=1, label="Top-K")
+                with gr.Row():
+                    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
+                    width = gr.Slider(256, 1024, value=512, step=64, label="Width")
+
+        with gr.Column(scale=2):
+            gr.Markdown("### πŸ“Š Comparison View")
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("**Previous**")
+                    prev_image = gr.Image(label="Previous", type="pil", height=450)
+                    prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False)
+                with gr.Column():
+                    gr.Markdown("**Current**")
+                    current_image = gr.Image(label="Current", type="pil", height=450)
+
+    gr.Markdown("### πŸ“ Example")
+    gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)
+
+    generate_btn.click(
+        fn=model.generate_image,
+        inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
+        outputs=[prev_image, current_image, prev_prompt_display]
+    )
+
+    reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])
+
+if __name__ == "__main__":
+    demo.launch()
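
A note on the download logic in this commit: the per-file loop in download_checkpoint can be collapsed into a single call. A sketch (not part of the commit) using huggingface_hub's snapshot_download, which also avoids the local_dir_use_symlinks argument that recent huggingface_hub versions deprecate:

    # Sketch: fetch all checkpoint files in one call instead of looping.
    # The patterns below assume only these four files are needed.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="Shaoan/ConceptAligner-Weights",
        local_dir="./checkpoint",
        allow_patterns=["model*.safetensors", "empty_pooled_clip.pt"],
    )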