Shaoan committed on
Commit
789491e
·
verified ·
1 Parent(s): 5c693f5

Upload folder using huggingface_hub

Browse files
Files changed (24) hide show
  1. .gitattributes +20 -0
  2. aligner.py +19 -16
  3. app.py +12 -12
  4. boy.jpg +3 -0
  5. dog.jpg +3 -0
  6. dragon.jpg +3 -0
  7. dump.jpg +3 -0
  8. egg.jpg +3 -0
  9. elephant.jpg +3 -0
  10. family.jpg +3 -0
  11. fold.jpg +3 -0
  12. fruit.jpg +3 -0
  13. girl.jpg +3 -0
  14. girl2.jpg +3 -0
  15. laion.jpg +3 -0
  16. lizard.jpg +3 -0
  17. pole.jpg +3 -0
  18. requirements.txt +1 -1
  19. robot.jpg +3 -0
  20. robot2.jpg +3 -0
  21. robot3.jpg +3 -0
  22. sky.jpg +3 -0
  23. whale.jpg +3 -0
  24. wood.jpg +3 -0
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ boy.jpg filter=lfs diff=lfs merge=lfs -text
37
+ dog.jpg filter=lfs diff=lfs merge=lfs -text
38
+ dragon.jpg filter=lfs diff=lfs merge=lfs -text
39
+ dump.jpg filter=lfs diff=lfs merge=lfs -text
40
+ egg.jpg filter=lfs diff=lfs merge=lfs -text
41
+ elephant.jpg filter=lfs diff=lfs merge=lfs -text
42
+ family.jpg filter=lfs diff=lfs merge=lfs -text
43
+ fold.jpg filter=lfs diff=lfs merge=lfs -text
44
+ fruit.jpg filter=lfs diff=lfs merge=lfs -text
45
+ girl.jpg filter=lfs diff=lfs merge=lfs -text
46
+ girl2.jpg filter=lfs diff=lfs merge=lfs -text
47
+ laion.jpg filter=lfs diff=lfs merge=lfs -text
48
+ lizard.jpg filter=lfs diff=lfs merge=lfs -text
49
+ pole.jpg filter=lfs diff=lfs merge=lfs -text
50
+ robot.jpg filter=lfs diff=lfs merge=lfs -text
51
+ robot2.jpg filter=lfs diff=lfs merge=lfs -text
52
+ robot3.jpg filter=lfs diff=lfs merge=lfs -text
53
+ sky.jpg filter=lfs diff=lfs merge=lfs -text
54
+ whale.jpg filter=lfs diff=lfs merge=lfs -text
55
+ wood.jpg filter=lfs diff=lfs merge=lfs -text
aligner.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from torch import nn
3
- #from refiner import Qwen2Connector
4
 
5
  import torch
6
  import torch.nn as nn
@@ -425,11 +425,11 @@ class ConceptAligner(nn.Module):
425
  empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
426
  self.register_buffer('empty_pooled_clip', empty_pooled_clip)
427
 
428
- test_eps = torch.randn([1, 300, per_dim], dtype=torch.bfloat16).to('cpu')*0.7
429
  self.register_buffer('test_eps', test_eps)
430
 
431
- self.init_proj = nn.Sequential(nn.Linear(768, 300*16), nn.SiLU())
432
- self.proj = nn.Sequential(nn.Linear(16, 1024), nn.SiLU(),
433
  nn.Linear(1024, 1024), nn.SiLU())
434
  self.text_proj = nn.Sequential(nn.Linear(4096, 1024), nn.SiLU(),
435
  nn.Linear(1024, 1024), nn.SiLU())
@@ -465,7 +465,7 @@ class ConceptAligner(nn.Module):
465
  device = text_features.device
466
 
467
  if image_features is not None:
468
- visual_hidden = self.proj(self.init_proj(image_features).view(len(image_features), 300, -1))
469
  text_hidden = self.text_proj(text_features.detach())
470
  hidden = visual_hidden - text_hidden
471
  mu = self.proj_mu(hidden)
@@ -510,13 +510,14 @@ if __name__ == '__main__':
510
 
511
  dim = 4096
512
  num_heads = 32
 
513
  dtype = torch.bfloat16
514
  model = ConceptAligner().to('cuda').to(dtype)
515
- x = torch.randn([5, 300, dim]).to('cuda').to(dtype)
516
- y = torch.randn([5, 300, dim]).to('cuda').to(dtype)
517
  i = torch.randn([5,768]).to('cuda').to(dtype)
518
  y[1] = y[0]
519
- m = torch.ones([5, 300]).to('cuda').to(dtype)
520
  m[:3,:128] = 0
521
  prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(x, i)
522
  print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
@@ -524,14 +525,15 @@ if __name__ == '__main__':
524
  for k in aux_info:
525
  print(k, aux_info[k].shape, aux_info[k].min(), aux_info[k].max(), aux_info[k].mean())
526
 
527
- from text_encoder import LoraT5Embedder
528
  from datasets import load_dataset
529
- dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
530
- item = dataset[0:4]
531
- another_item = dataset[0:4]
 
532
  from diffusers.models.normalization import RMSNorm
533
  clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14")
534
- clip_images = clip_processor(images=item['image'], return_tensors="pt").pixel_values.to('cuda:0').to(dtype)
535
  texts = []
536
  texts.append("""A heartwarming 3D rendered scene of
537
  an elderly farmer and a tiny orange
@@ -567,8 +569,8 @@ if __name__ == '__main__':
567
  texts.append(
568
  """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
569
 
570
- text_encoder = LoraT5Embedder(device='cuda').to(dtype)
571
- text_features, _, _, _, image_features, _ = text_encoder(texts, clip_images)
572
  print(text_features.shape, image_features.shape, ' >>>>>>>>> text input')
573
  images = []
574
  pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
@@ -576,7 +578,7 @@ if __name__ == '__main__':
576
 
577
  for txt_feat, img_feat in zip(text_features, image_features):
578
 
579
- prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(txt_feat.unsqueeze(0), img_feat.unsqueeze(0))
580
  image = pipe(
581
  prompt_embeds=prompt_embeds,
582
  pooled_prompt_embeds=pooled_prompt_embeds,
@@ -816,6 +818,7 @@ glow reminiscent of the glow of the moon. HD,
816
  # for (start_dim, end_dim) in [(0,4096), (1024,4096), (2048, 4096), (1024, 2048)]:
817
 
818
 
 
819
  for emb in ['floral', 'golden']:
820
  for temp in [2.5]:
821
  for thr in [-1, 0.5, 0.75, 0.85, 0.95]:
 
1
  import torch
2
  from torch import nn
3
+ from refiner import Qwen2Connector
4
 
5
  import torch
6
  import torch.nn as nn
 
425
  empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
426
  self.register_buffer('empty_pooled_clip', empty_pooled_clip)
427
 
428
+ test_eps = torch.randn([1, 512, per_dim], dtype=torch.bfloat16).to('cpu')*0.7
429
  self.register_buffer('test_eps', test_eps)
430
 
431
+ self.init_proj = nn.Sequential(nn.Linear(768, 512*8), nn.SiLU())
432
+ self.proj = nn.Sequential(nn.Linear(8, 1024), nn.SiLU(),
433
  nn.Linear(1024, 1024), nn.SiLU())
434
  self.text_proj = nn.Sequential(nn.Linear(4096, 1024), nn.SiLU(),
435
  nn.Linear(1024, 1024), nn.SiLU())
 
465
  device = text_features.device
466
 
467
  if image_features is not None:
468
+ visual_hidden = self.proj(self.init_proj(image_features).view(len(image_features), text_features.size(1), -1))
469
  text_hidden = self.text_proj(text_features.detach())
470
  hidden = visual_hidden - text_hidden
471
  mu = self.proj_mu(hidden)
 
510
 
511
  dim = 4096
512
  num_heads = 32
513
+
514
  dtype = torch.bfloat16
515
  model = ConceptAligner().to('cuda').to(dtype)
516
+ x = torch.randn([5, 512, dim]).to('cuda').to(dtype)
517
+ y = torch.randn([5, 512, dim]).to('cuda').to(dtype)
518
  i = torch.randn([5,768]).to('cuda').to(dtype)
519
  y[1] = y[0]
520
+ m = torch.ones([5, 512]).to('cuda').to(dtype)
521
  m[:3,:128] = 0
522
  prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(x, i)
523
  print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
 
525
  for k in aux_info:
526
  print(k, aux_info[k].shape, aux_info[k].min(), aux_info[k].max(), aux_info[k].mean())
527
 
528
+ from text_encoder import JacobianLoraT5Embedder
529
  from datasets import load_dataset
530
+ #dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
531
+ #item = dataset[0:4]
532
+ #another_item = dataset[0:4]
533
+ image = Image.open('example512.jpg').convert('RGB')
534
  from diffusers.models.normalization import RMSNorm
535
  clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14")
536
+ clip_images = clip_processor(images=image, return_tensors="pt").pixel_values.to('cuda:0').to(dtype).repeat(4,1,1,1)
537
  texts = []
538
  texts.append("""A heartwarming 3D rendered scene of
539
  an elderly farmer and a tiny orange
 
569
  texts.append(
570
  """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
571
 
572
+ text_encoder = JacobianLoraT5Embedder(device='cuda', max_length=512, num_jacobian_samples=1).to(torch.bfloat16)
573
+ text_features, image_features, _, _ = text_encoder(texts, clip_images)
574
  print(text_features.shape, image_features.shape, ' >>>>>>>>> text input')
575
  images = []
576
  pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
 
578
 
579
  for txt_feat, img_feat in zip(text_features, image_features):
580
 
581
+ prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(txt_feat.unsqueeze(0), None)
582
  image = pipe(
583
  prompt_embeds=prompt_embeds,
584
  pooled_prompt_embeds=pooled_prompt_embeds,
 
818
  # for (start_dim, end_dim) in [(0,4096), (1024,4096), (2048, 4096), (1024, 2048)]:
819
 
820
 
821
+
822
  for emb in ['floral', 'golden']:
823
  for temp in [2.5]:
824
  for thr in [-1, 0.5, 0.75, 0.85, 0.95]:
app.py CHANGED
@@ -180,7 +180,7 @@ def reset_history():
180
  css = """
181
  #col-container {
182
  margin: 0 auto;
183
- max-width: 1400px;
184
  }
185
  """
186
 
@@ -194,15 +194,17 @@ with gr.Blocks(css=css, title="ConceptAligner") as demo:
194
 
195
  with gr.Row():
196
  with gr.Column(scale=1):
197
- prompt_input = gr.Textbox(
198
  label="Prompt",
199
- lines=8,
200
- placeholder="Describe your image in detail...",
 
 
201
  )
202
 
203
  with gr.Row():
204
  generate_btn = gr.Button("✨ Generate", variant="primary", scale=3)
205
- reset_btn = gr.Button("🔄 Clear History", variant="secondary", scale=1)
206
 
207
  with gr.Accordion("⚙️ Settings", open=False):
208
  seed = gr.Slider(
@@ -214,21 +216,19 @@ with gr.Blocks(css=css, title="ConceptAligner") as demo:
214
  )
215
 
216
  guidance_scale = gr.Slider(
217
- label="Guidance Scale",
218
  minimum=1.0,
219
  maximum=10.0,
220
  step=0.5,
221
  value=3.5,
222
- info="Higher = follows prompt more closely (3-4 recommended)"
223
  )
224
 
225
  num_inference_steps = gr.Slider(
226
- label="Number of Steps",
227
  minimum=10,
228
  maximum=50,
229
  step=1,
230
  value=20,
231
- info="More steps = higher quality but slower"
232
  )
233
 
234
  with gr.Row():
@@ -263,17 +263,17 @@ with gr.Blocks(css=css, title="ConceptAligner") as demo:
263
  with gr.Row():
264
  with gr.Column():
265
  gr.Markdown("**Previous**")
266
- prev_image = gr.Image(label="Previous", show_label=False, type="pil", height=450)
267
  prev_prompt_display = gr.Textbox(
268
  label="Previous Prompt",
269
- lines=3,
270
  interactive=False,
271
  show_label=False
272
  )
273
 
274
  with gr.Column():
275
  gr.Markdown("**Latest**")
276
- current_image = gr.Image(label="Current", show_label=False, type="pil", height=450)
277
 
278
  gr.Markdown("### 📝 Try This Example")
279
  gr.Examples(
 
180
  css = """
181
  #col-container {
182
  margin: 0 auto;
183
+ max-width: 1200px;
184
  }
185
  """
186
 
 
194
 
195
  with gr.Row():
196
  with gr.Column(scale=1):
197
+ prompt_input = gr.Text(
198
  label="Prompt",
199
+ show_label=False,
200
+ max_lines=3,
201
+ placeholder="Describe your image...",
202
+ container=False,
203
  )
204
 
205
  with gr.Row():
206
  generate_btn = gr.Button("✨ Generate", variant="primary", scale=3)
207
+ reset_btn = gr.Button("🔄 Clear", variant="secondary", scale=1)
208
 
209
  with gr.Accordion("⚙️ Settings", open=False):
210
  seed = gr.Slider(
 
216
  )
217
 
218
  guidance_scale = gr.Slider(
219
+ label="Creativity Level",
220
  minimum=1.0,
221
  maximum=10.0,
222
  step=0.5,
223
  value=3.5,
 
224
  )
225
 
226
  num_inference_steps = gr.Slider(
227
+ label="Quality (steps)",
228
  minimum=10,
229
  maximum=50,
230
  step=1,
231
  value=20,
 
232
  )
233
 
234
  with gr.Row():
 
263
  with gr.Row():
264
  with gr.Column():
265
  gr.Markdown("**Previous**")
266
+ prev_image = gr.Image(label="Previous", show_label=False, type="pil", height=400)
267
  prev_prompt_display = gr.Textbox(
268
  label="Previous Prompt",
269
+ lines=2,
270
  interactive=False,
271
  show_label=False
272
  )
273
 
274
  with gr.Column():
275
  gr.Markdown("**Latest**")
276
+ current_image = gr.Image(label="Current", show_label=False, type="pil", height=400)
277
 
278
  gr.Markdown("### 📝 Try This Example")
279
  gr.Examples(
boy.jpg ADDED

Git LFS Details

  • SHA256: c1bfe31cd434a1d95056268d0b5bcc15901434696b5964366063cadd33b4d8b1
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
dog.jpg ADDED

Git LFS Details

  • SHA256: 336efade4f6239f6b4b2e5a158860ac97e6da8c9dff1270ec24bc63f083f4447
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB
dragon.jpg ADDED

Git LFS Details

  • SHA256: d0f3bc03b863c39445a1e4926c1e6c669d76eb772e457d08d48cd12c24e9cc53
  • Pointer size: 131 Bytes
  • Size of remote file: 231 kB
dump.jpg ADDED

Git LFS Details

  • SHA256: 92ccbd6d558cb241887f72c50a6692378b90589d1edde08d16398d59925bd7c6
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
egg.jpg ADDED

Git LFS Details

  • SHA256: 17d1316c1a8b651020aef76c52986ac93a306b0ed4705fa8facdebb93b525d78
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB
elephant.jpg ADDED

Git LFS Details

  • SHA256: 82999fef6cedd1f202c67ce83680c9d71ae65a32c110f26223b2b93b80d9d8b3
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
family.jpg ADDED

Git LFS Details

  • SHA256: d6e8926e7e8ac3df2d7ffa7ca01b35a4bc6c1ba194d3f99312525bd5c19cb3fa
  • Pointer size: 131 Bytes
  • Size of remote file: 407 kB
fold.jpg ADDED

Git LFS Details

  • SHA256: 742562a13a3b3d1bb68b0ae4f8f0f86219e87e971b5d8ea09c1977c87b6b8f91
  • Pointer size: 131 Bytes
  • Size of remote file: 211 kB
fruit.jpg ADDED

Git LFS Details

  • SHA256: 892be3f440c55076ba5cb7801aa21908de35169648e6f7461976b2594facb34e
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
girl.jpg ADDED

Git LFS Details

  • SHA256: 9bbecfd34c6d24766712b6e4f17efe1c43c6bdfba4968a93cf108ff715fce2f7
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
girl2.jpg ADDED

Git LFS Details

  • SHA256: 5e683ae1ba4f37489ef64d8fbed17f26d97df0d2b65109b2194ea8b657b0c473
  • Pointer size: 131 Bytes
  • Size of remote file: 482 kB
laion.jpg ADDED

Git LFS Details

  • SHA256: 6e0a3f68b61af2a20a291db616617647adfab6e21e13176289558e350dc7431a
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
lizard.jpg ADDED

Git LFS Details

  • SHA256: 4c419ea424a520b74284749a2877ec51e327ce878da28b47ad3f3eb1ad0095db
  • Pointer size: 131 Bytes
  • Size of remote file: 219 kB
pole.jpg ADDED

Git LFS Details

  • SHA256: 2499bf2ce65a0be8fe8dbab8702ca07e92014152dc7d7c03aa1a12f97a886740
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  torch==2.8.0
2
  torchvision==0.23.0
3
- gradio==3.50.2
4
  diffusers==0.35.2
5
  transformers==4.57.1
6
  safetensors==0.6.2
 
1
  torch==2.8.0
2
  torchvision==0.23.0
3
+ gradio==6.9.0
4
  diffusers==0.35.2
5
  transformers==4.57.1
6
  safetensors==0.6.2
robot.jpg ADDED

Git LFS Details

  • SHA256: 9c946e843bcdcf4d90b43e51dc7b0fc3082fbfcf10810096dbea4b151271eb19
  • Pointer size: 131 Bytes
  • Size of remote file: 219 kB
robot2.jpg ADDED

Git LFS Details

  • SHA256: 1e6079e438552356abd0e7433b7a1e135e5a3a273efeebc7c314057ff6dba705
  • Pointer size: 131 Bytes
  • Size of remote file: 213 kB
robot3.jpg ADDED

Git LFS Details

  • SHA256: 94995fbd36f63156733076a7dd163e92b4cb539c9e626c982562b1b482c32c16
  • Pointer size: 131 Bytes
  • Size of remote file: 243 kB
sky.jpg ADDED

Git LFS Details

  • SHA256: d9524732b0357c4202e737d44beffa5c2e9072b1b674a5eec28a195e034c276c
  • Pointer size: 131 Bytes
  • Size of remote file: 206 kB
whale.jpg ADDED

Git LFS Details

  • SHA256: 93ded9538ba19e0fe49271126827ffa5f580ac6a181548a8a5b1644738009e48
  • Pointer size: 131 Bytes
  • Size of remote file: 281 kB
wood.jpg ADDED

Git LFS Details

  • SHA256: 23852ccdf0df403395ce061991e59932784a046a8c5d456df33883d82e723d46
  • Pointer size: 131 Bytes
  • Size of remote file: 378 kB