Shaoan committed
Commit 860b278 · verified · 1 Parent(s): fd9e0dc

Upload folder using huggingface_hub
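
This commit message is the default one written by the upload_folder helper in huggingface_hub. For reference, a commit like this is typically produced with a few lines of Python; the repo id and type below are assumptions, since this page does not show which repo hosts app.py:

from huggingface_hub import HfApi

api = HfApi()  # picks up a token from HF_TOKEN or the cached `huggingface-cli login`
api.upload_folder(
    folder_path=".",                  # local directory containing app.py
    repo_id="Shaoan/ConceptAligner",  # assumption: the hosting repo id is not shown on this page
    repo_type="space",                # assumption: a Gradio app.py suggests a Space
)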

Files changed (1)
  1. app.py +184 -173
app.py CHANGED
@@ -6,7 +6,7 @@ Downloads weights from model repo at startup
 import torch
 import gradio as gr
 import os
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, login
 from safetensors.torch import load_file
 from aligner import ConceptAligner
 from text_encoder import LoraT5Embedder
@@ -14,143 +14,154 @@ from pipeline import CustomFluxKontextPipeline
 from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
 from peft import LoraConfig
 
+# Login with token from environment
+HF_TOKEN = os.environ.get("HF_TOKEN")
+if HF_TOKEN:
+    login(token=HF_TOKEN)
+    print("✓ Logged in to Hugging Face")
+else:
+    print("⚠️ Warning: No HF_TOKEN found in environment")
+
 # Configuration
 MODEL_REPO = "Shaoan/ConceptAligner-Weights"
 CHECKPOINT_DIR = "./checkpoint"
 
 EXAMPLE_PROMPTS = [
-    [
-        """In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""]
+    ["""In the image, a single white duck walks proudly across a cobblestone street. It wears a red ribbon around its neck, and the morning sun glints off puddles from a recent rain. In the background, a few people watch and smile, giving the scene a playful charm. The duck's confident stride and upright posture make it appear oddly dignified."""]
 ]
 
-
 def download_checkpoint():
-    """Download checkpoint files from HF model repo"""
-    print("Downloading checkpoint files...")
-
-    files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"]
-    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
-
-    for filename in files:
-        local_path = os.path.join(CHECKPOINT_DIR, filename)
-        if not os.path.exists(local_path):
-            print(f" Downloading {filename}...")
-            hf_hub_download(
-                repo_id=MODEL_REPO,
-                filename=filename,
-                local_dir=CHECKPOINT_DIR,
-                local_dir_use_symlinks=False
-            )
-
-    print("✓ All files ready!")
-
+    """Download checkpoint files from HF model repo"""
+    print("Downloading checkpoint files...")
+
+    files = ["model.safetensors", "model_1.safetensors", "model_2.safetensors", "empty_pooled_clip.pt"]
+    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
+
+    for filename in files:
+        local_path = os.path.join(CHECKPOINT_DIR, filename)
+        if not os.path.exists(local_path):
+            print(f" Downloading {filename}...")
+            hf_hub_download(
+                repo_id=MODEL_REPO,
+                filename=filename,
+                local_dir=CHECKPOINT_DIR,
+                local_dir_use_symlinks=False,
+                token=HF_TOKEN
+            )
+
+    print("✓ All files ready!")
 
 class ConceptAlignerModel:
-    def __init__(self):
-        download_checkpoint()
-
-        self.checkpoint_path = CHECKPOINT_DIR
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
-        self.previous_image = None
-        self.previous_prompt = None
-
-        self.setup_models()
-
-    def setup_models(self):
-        """Load all models"""
-        print(f"Loading models on {self.device}...")
-
-        # Load ConceptAligner
-        self.model = ConceptAligner().to(self.device).to(self.dtype)
-        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
-        self.model.load_state_dict(adapter_state, strict=True)
-
-        # Load T5 encoder
-        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
-        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
-        if "t5_encoder.shared.weight" in adapter_state:
-            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
-        self.text_encoder.load_state_dict(adapter_state, strict=True)
-
-        # Load VAE
-        vae = AutoencoderKL.from_pretrained(
-            'black-forest-labs/FLUX.1-dev', subfolder="vae", torch_dtype=self.dtype
-        ).to(self.device)
-
-        # Load transformer
-        transformer = FluxTransformer2DModel.from_pretrained(
-            'black-forest-labs/FLUX.1-dev', subfolder="transformer", torch_dtype=self.dtype
-        )
-
-        transformer_lora_config = LoraConfig(
-            r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
-            target_modules=[
-                "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
-                "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
-                "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
-                "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
-            ],
-        )
-        transformer.add_adapter(transformer_lora_config)
-        transformer.context_embedder.requires_grad_(True)
-
-        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
-        transformer.load_state_dict(transformer_state, strict=True)
-        transformer = transformer.to(self.device)
-
-        # Load empty pooled clip
-        self.empty_pooled_clip = torch.load(
-            os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
-            map_location=self.device
-        ).to(self.dtype)
-
-        # Create pipeline
-        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-            'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
-        )
-
-        self.pipe = CustomFluxKontextPipeline(
-            scheduler=noise_scheduler,
-            aligner=self.model,
-            transformer=transformer,
-            vae=vae,
-            text_embedder=self.text_encoder,
-        ).to(self.device)
-
-        print("✓ Model loaded!")
-
-    @torch.no_grad()
-    def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
-                       guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
-        if not prompt.strip():
-            return self.previous_image, None, self.previous_prompt or ""
-
-        try:
-            generator = torch.Generator(device=self.device).manual_seed(int(seed))
-            current_image = self.pipe(
-                prompt=prompt, guidance_scale=guidance_scale, true_cfg_scale=true_cf_scale,
-                max_sequence_length=512, num_inference_steps=num_inference_steps,
-                height=height, width=width, generator=generator,
-            ).images[0]
-
-            prev_image = self.previous_image
-            prev_prompt = self.previous_prompt or "No previous generation"
-
-            self.previous_image = current_image
-            self.previous_prompt = prompt
-
-            return prev_image, current_image, prev_prompt
-        except Exception as e:
-            print(f"Error: {e}")
-            return self.previous_image, None, self.previous_prompt or ""
-
-    def reset_history(self):
-        self.previous_image = None
-        self.previous_prompt = None
-        return None, None, "No previous generation"
-
+    def __init__(self):
+        download_checkpoint()
+
+        self.checkpoint_path = CHECKPOINT_DIR
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+        self.previous_image = None
+        self.previous_prompt = None
+
+        self.setup_models()
+
+    def setup_models(self):
+        """Load all models"""
+        print(f"Loading models on {self.device}...")
+
+        # Load ConceptAligner
+        self.model = ConceptAligner().to(self.device).to(self.dtype)
+        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_1.safetensors"))
+        self.model.load_state_dict(adapter_state, strict=True)
+
+        # Load T5 encoder
+        self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
+        adapter_state = load_file(os.path.join(self.checkpoint_path, "model_2.safetensors"))
+        if "t5_encoder.shared.weight" in adapter_state:
+            adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
+        self.text_encoder.load_state_dict(adapter_state, strict=True)
+
+        # Load VAE with token
+        vae = AutoencoderKL.from_pretrained(
+            'black-forest-labs/FLUX.1-dev',
+            subfolder="vae",
+            torch_dtype=self.dtype,
+            token=HF_TOKEN
+        ).to(self.device)
+
+        # Load transformer with token
+        transformer = FluxTransformer2DModel.from_pretrained(
+            'black-forest-labs/FLUX.1-dev',
+            subfolder="transformer",
+            torch_dtype=self.dtype,
+            token=HF_TOKEN
+        )
+
+        transformer_lora_config = LoraConfig(
+            r=256, lora_alpha=256, lora_dropout=0.0, init_lora_weights="gaussian",
+            target_modules=[
+                "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
+                "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
+                "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
+                "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
+            ],
+        )
+        transformer.add_adapter(transformer_lora_config)
+        transformer.context_embedder.requires_grad_(True)
+
+        transformer_state = load_file(os.path.join(self.checkpoint_path, "model.safetensors"))
+        transformer.load_state_dict(transformer_state, strict=True)
+        transformer = transformer.to(self.device)
+
+        # Load empty pooled clip
+        self.empty_pooled_clip = torch.load(
+            os.path.join(self.checkpoint_path, "empty_pooled_clip.pt"),
+            map_location=self.device
+        ).to(self.dtype)
+
+        # Create pipeline
+        noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            'black-forest-labs/FLUX.1-dev', subfolder="scheduler", token=HF_TOKEN
+        )
+
+        self.pipe = CustomFluxKontextPipeline(
+            scheduler=noise_scheduler,
+            aligner=self.model,
+            transformer=transformer,
+            vae=vae,
+            text_embedder=self.text_encoder,
+        ).to(self.device)
+
+        print("✓ Model loaded!")
+
+    @torch.no_grad()
+    def generate_image(self, prompt, threshold=0.0, topk=0, height=512, width=512,
+                       guidance_scale=3.5, true_cf_scale=1.0, num_inference_steps=20, seed=1995):
+        if not prompt.strip():
+            return self.previous_image, None, self.previous_prompt or ""
+
+        try:
+            generator = torch.Generator(device=self.device).manual_seed(int(seed))
+            current_image = self.pipe(
+                prompt=prompt, guidance_scale=guidance_scale, true_cfg_scale=true_cf_scale,
+                max_sequence_length=512, num_inference_steps=num_inference_steps,
+                height=height, width=width, generator=generator,
+            ).images[0]
+
+            prev_image = self.previous_image
+            prev_prompt = self.previous_prompt or "No previous generation"
+
+            self.previous_image = current_image
+            self.previous_prompt = prompt
+
+            return prev_image, current_image, prev_prompt
+        except Exception as e:
+            print(f"Error: {e}")
+            return self.previous_image, None, self.previous_prompt or ""
+
+    def reset_history(self):
+        self.previous_image = None
+        self.previous_prompt = None
+        return None, None, "No previous generation"
 
 # Initialize model
 print("Initializing ConceptAligner...")
@@ -158,50 +169,50 @@ model = ConceptAlignerModel()
 
 # Create Gradio interface
 with gr.Blocks(title="ConceptAligner", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...")
-
-            with gr.Row():
-                generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3)
-                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1)
-
-            with gr.Accordion("⚙️ Settings", open=True):
-                guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
-                num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps")
-                seed = gr.Number(value=0, label="Seed", precision=0)
-
-            with gr.Accordion("🔬 Advanced", open=False):
-                true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG")
-                threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold")
-                topk = gr.Slider(0, 300, value=0, step=1, label="Top-K")
-                with gr.Row():
-                    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
-                    width = gr.Slider(256, 1024, value=512, step=64, label="Width")
-
-        with gr.Column(scale=2):
-            gr.Markdown("### 📊 Comparison View")
-            with gr.Row():
-                with gr.Column():
-                    gr.Markdown("**Previous**")
-                    prev_image = gr.Image(label="Previous", type="pil", height=450)
-                    prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False)
-                with gr.Column():
-                    gr.Markdown("**Current**")
-                    current_image = gr.Image(label="Current", type="pil", height=450)
-
-    gr.Markdown("### 📝 Example")
-    gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)
-
-    generate_btn.click(
-        fn=model.generate_image,
-        inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
-        outputs=[prev_image, current_image, prev_prompt_display]
-    )
-
-    reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])
+    gr.Markdown("# 🎨 ConceptAligner Demo\nGenerate images with fine-tuned concept alignment!")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            prompt_input = gr.Textbox(label="Prompt", lines=6, placeholder="Describe your image...")
+
+            with gr.Row():
+                generate_btn = gr.Button("✨ Generate", variant="primary", size="lg", scale=3)
+                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1)
+
+            with gr.Accordion("⚙️ Settings", open=True):
+                guidance_scale = gr.Slider(1.0, 10.0, value=3.5, step=0.5, label="Guidance Scale")
+                num_steps = gr.Slider(10, 50, value=20, step=1, label="Steps")
+                seed = gr.Number(value=0, label="Seed", precision=0)
+
+            with gr.Accordion("🔬 Advanced", open=False):
+                true_cfg_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="True CFG")
+                threshold = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Threshold")
+                topk = gr.Slider(0, 300, value=0, step=1, label="Top-K")
+                with gr.Row():
+                    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
+                    width = gr.Slider(256, 1024, value=512, step=64, label="Width")
+
+        with gr.Column(scale=2):
+            gr.Markdown("### 📊 Comparison View")
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("**Previous**")
+                    prev_image = gr.Image(label="Previous", type="pil", height=450)
+                    prev_prompt_display = gr.Textbox(label="Previous Prompt", lines=3, interactive=False)
+                with gr.Column():
+                    gr.Markdown("**Current**")
+                    current_image = gr.Image(label="Current", type="pil", height=450)
+
+    gr.Markdown("### 📝 Example")
+    gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input)
+
+    generate_btn.click(
+        fn=model.generate_image,
+        inputs=[prompt_input, threshold, topk, height, width, guidance_scale, true_cfg_scale, num_steps, seed],
+        outputs=[prev_image, current_image, prev_prompt_display]
+    )
+
+    reset_btn.click(fn=model.reset_history, outputs=[prev_image, current_image, prev_prompt_display])
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
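
The substantive change in this commit is authentication wiring: login() runs once at import time with a token read from the environment, and the same token is also passed explicitly to hf_hub_download and to each from_pretrained call so that the gated black-forest-labs/FLUX.1-dev repo resolves. A minimal standalone sketch of that pattern, using the repo id and one filename taken from the diff (running it requires an HF_TOKEN with access to the referenced repos):

import os
from huggingface_hub import hf_hub_download, login

# Read the token from the environment, as app.py now does at import time.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)  # caches the token for subsequent hub calls

# Passing token= per call also works; this mirrors download_checkpoint().
path = hf_hub_download(
    repo_id="Shaoan/ConceptAligner-Weights",
    filename="empty_pooled_clip.pt",
    local_dir="./checkpoint",
    token=HF_TOKEN,
)
print(f"Downloaded to {path}")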