RishubhPar committed on
Commit bd90279 · verified · 1 Parent(s): 8b8b01e

added other files.

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +61 -0
  2. model/.DS_Store +0 -0
  3. model/__init__.py +0 -0
  4. model/__pycache__/__init__.cpython-310.pyc +0 -0
  5. model/__pycache__/sliders_model.cpython-310.pyc +0 -0
  6. model/__pycache__/sliders_pipeline.cpython-310.pyc +0 -0
  7. model/__pycache__/transformer_flux.cpython-310.pyc +0 -0
  8. model/sliders_model.py +102 -0
  9. model/sliders_pipeline.py +468 -0
  10. model/transformer_flux.py +608 -0
  11. model_weights/.DS_Store +0 -0
  12. model_weights/pytorch_lora_weights.safetensors +3 -0
  13. model_weights/random_states_0.pkl +3 -0
  14. model_weights/scheduler.bin +3 -0
  15. model_weights/slider_projector.pth +3 -0
  16. sample_images/precomputed/aesthetic_model2_vangogh.png +3 -0
  17. sample_images/precomputed/aesthetic_model2_vangogh/image_0.png +3 -0
  18. sample_images/precomputed/aesthetic_model2_vangogh/image_1.png +3 -0
  19. sample_images/precomputed/aesthetic_model2_vangogh/image_10.png +3 -0
  20. sample_images/precomputed/aesthetic_model2_vangogh/image_2.png +3 -0
  21. sample_images/precomputed/aesthetic_model2_vangogh/image_3.png +3 -0
  22. sample_images/precomputed/aesthetic_model2_vangogh/image_4.png +3 -0
  23. sample_images/precomputed/aesthetic_model2_vangogh/image_5.png +3 -0
  24. sample_images/precomputed/aesthetic_model2_vangogh/image_6.png +3 -0
  25. sample_images/precomputed/aesthetic_model2_vangogh/image_7.png +3 -0
  26. sample_images/precomputed/aesthetic_model2_vangogh/image_8.png +3 -0
  27. sample_images/precomputed/aesthetic_model2_vangogh/image_9.png +3 -0
  28. sample_images/precomputed/enfield3_winter_snow.png +3 -0
  29. sample_images/precomputed/enfield3_winter_snow/image_0.png +3 -0
  30. sample_images/precomputed/enfield3_winter_snow/image_1.png +3 -0
  31. sample_images/precomputed/enfield3_winter_snow/image_10.png +3 -0
  32. sample_images/precomputed/enfield3_winter_snow/image_2.png +3 -0
  33. sample_images/precomputed/enfield3_winter_snow/image_3.png +3 -0
  34. sample_images/precomputed/enfield3_winter_snow/image_4.png +3 -0
  35. sample_images/precomputed/enfield3_winter_snow/image_5.png +3 -0
  36. sample_images/precomputed/enfield3_winter_snow/image_6.png +3 -0
  37. sample_images/precomputed/enfield3_winter_snow/image_7.png +3 -0
  38. sample_images/precomputed/enfield3_winter_snow/image_8.png +3 -0
  39. sample_images/precomputed/enfield3_winter_snow/image_9.png +3 -0
  40. sample_images/precomputed/jackson_fluffy.png +3 -0
  41. sample_images/precomputed/jackson_fluffy/image_0.png +3 -0
  42. sample_images/precomputed/jackson_fluffy/image_1.png +3 -0
  43. sample_images/precomputed/jackson_fluffy/image_10.png +3 -0
  44. sample_images/precomputed/jackson_fluffy/image_11.png +3 -0
  45. sample_images/precomputed/jackson_fluffy/image_2.png +3 -0
  46. sample_images/precomputed/jackson_fluffy/image_3.png +3 -0
  47. sample_images/precomputed/jackson_fluffy/image_4.png +3 -0
  48. sample_images/precomputed/jackson_fluffy/image_5.png +3 -0
  49. sample_images/precomputed/jackson_fluffy/image_6.png +3 -0
  50. sample_images/precomputed/jackson_fluffy/image_7.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,64 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_0.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_1.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_10.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_2.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_3.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_4.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_5.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_6.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_7.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_8.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/aesthetic_model2_vangogh/image_9.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_0.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_1.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_10.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_2.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_3.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_4.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_5.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_6.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_7.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_8.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/enfield3_winter_snow/image_9.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_0.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_1.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_10.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_11.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_2.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_3.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_4.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_5.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_6.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_7.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_8.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/jackson_fluffy/image_9.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_0.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_1.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_10.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_2.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_3.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_4.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_5.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_6.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_7.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_8.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/light_lamp_blue_side/image_9.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_0.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_1.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_10.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_2.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_3.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_4.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_5.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_6.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_7.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_8.png filter=lfs diff=lfs merge=lfs -text
+ sample_images/precomputed/venice1_grow_ivy/image_9.png filter=lfs diff=lfs merge=lfs -text
model/.DS_Store ADDED
Binary file (6.15 kB)

model/__init__.py ADDED
File without changes

model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes)

model/__pycache__/sliders_model.cpython-310.pyc ADDED
Binary file (2.29 kB)

model/__pycache__/sliders_pipeline.cpython-310.pyc ADDED
Binary file (8.09 kB)

model/__pycache__/transformer_flux.cpython-310.pyc ADDED
Binary file (16.9 kB)
model/sliders_model.py ADDED
@@ -0,0 +1,102 @@
+ import torch
+ import numpy as np
+
+
+ class SliderProjector(torch.nn.Module):
+     def __init__(
+         self,
+         out_dim,  # dimension of the output token that the projector will generate
+         pe_dim,  # dimension of the positional embedding applied to the scalar slider value
+         n_layers=4,
+         is_clip_input=True,  # whether the CLIP embeddings are passed to the projector along with the slider scalar
+     ):
+         super().__init__()
+         self.out_dim = out_dim
+         self.pe_dim = pe_dim
+         self.is_clip_input = is_clip_input
+
+         # accumulate the layers in a list before building the model
+         layers = []
+         pe_extender_dim = 768
+
+         # if the CLIP embeddings are passed along with the slider scalar, widen the input of the projector net
+         if is_clip_input:
+             in_dim = pe_extender_dim + 768
+         else:
+             in_dim = pe_extender_dim
+
+         # stack (n_layers - 1) Linear + ReLU blocks, followed by a final Linear layer
+         for i in range(n_layers - 1):
+             layers.append(torch.nn.Linear(in_dim, out_dim))
+             layers.append(torch.nn.ReLU())
+             in_dim = out_dim
+         layers.append(torch.nn.Linear(in_dim, out_dim))
+
+         # a simple linear layer to extend the positional encoding into a higher-dimensional space
+         self.pe_extender = torch.nn.Linear(pe_dim, 768)
+         # the extended encoding is then passed through a projector network
+         self.projector = torch.nn.Sequential(*layers)
+
+     # a simple sinusoidal encoding of the scalar slider value
+     def posEnc(self, s):
+         pe = torch.stack([torch.sin(torch.pi * s), torch.cos(torch.pi * s)], dim=-1)
+         return pe
+
+     # projects the slider value (optionally fused with CLIP embeddings) to a token embedding that conditions the diffusion model
+     def forward(self, s, clip_embeddings=None):
+         # apply the positional embedding to the input scalar
+         x_pe = self.posEnc(s)
+         x_scale_embedding = self.pe_extender(x_pe)  # (1, 768)
+
+         if clip_embeddings is not None:  # if the CLIP input is passed, concatenate it with the scalar embedding
+             # print("clip embeddings shape: {}".format(clip_embeddings.shape))
+             x_combined_embedding = torch.cat([x_scale_embedding, clip_embeddings], dim=-1)  # (1, 768 + 768)
+         else:  # without CLIP input; only valid when the projector was built with is_clip_input=False
+             x_combined_embedding = x_scale_embedding
+
+         x_proj = self.projector(x_combined_embedding)
+         # print("x proj shape: {}".format(x_proj.shape))
+         return x_proj
+
+
+ class SliderProjector_wo_clip(torch.nn.Module):
+     def __init__(
+         self,
+         out_dim,  # dimension of the output token that the projector will generate
+         pe_dim,  # dimension of the positional embedding
+         n_layers=4,
+         is_clip_input=False,  # whether the CLIP embeddings are passed to the projector (unused in this variant)
+     ):
+         super().__init__()
+         self.out_dim = out_dim
+         self.pe_dim = pe_dim
+
+         # accumulate the layers in a list before building the model
+         layers = []
+         pe_extender_dim = 768
+
+         # extend the input dimension to 768 with a linear layer to keep the dimensions consistent with the CLIP-based model
+         in_dim = pe_extender_dim
+
+         # stack (n_layers - 1) Linear + ReLU blocks, followed by a final Linear layer
+         for i in range(n_layers - 1):
+             layers.append(torch.nn.Linear(in_dim, out_dim))
+             layers.append(torch.nn.ReLU())
+             in_dim = out_dim
+         layers.append(torch.nn.Linear(in_dim, out_dim))
+
+         # a pe extender with the same output dimension as the CLIP embeddings
+         self.pe_extender = torch.nn.Linear(pe_dim, 768)
+         # the extended encoding is then passed through a projector network
+         self.projector = torch.nn.Sequential(*layers)
+
+     def posEnc(self, s):
+         pe = torch.stack([torch.sin(torch.pi * s), torch.cos(torch.pi * s)], dim=-1)
+         return pe
+
+     # projects the slider value to a token embedding that conditions the diffusion model
+     def forward(self, s):
+         x_pe = self.posEnc(s)
+         x_scale_embedding = self.pe_extender(x_pe)
+
+         x_proj = self.projector(x_scale_embedding)
+         return x_proj
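
A minimal usage sketch of the projector (not part of this commit; out_dim=4096 and the input shapes are assumptions: 4096 matches the Flux joint attention dimension used when the slider token is concatenated to the text embeddings, and pe_dim=2 follows from posEnc stacking one sine and one cosine term):

import torch
from model.sliders_model import SliderProjector

projector = SliderProjector(out_dim=4096, pe_dim=2, n_layers=4, is_clip_input=True)

s = torch.tensor([[0.5]])             # scalar slider strength, batch of 1
clip_pooled = torch.randn(1, 1, 768)  # stand-in for pooled_prompt_embeds.unsqueeze(0)
token = projector(s, clip_pooled)     # -> (1, 1, 4096), a single slider token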
model/sliders_pipeline.py ADDED
@@ -0,0 +1,468 @@
+ # kontext_sliders_pipeline.py
+ import torch
+ from diffusers import FluxKontextPipeline  # base pipeline from Diffusers
+ import inspect
+ from typing import Any, Callable, Dict, List, Optional, Union
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+ import numpy as np
+ from diffusers.pipelines.flux.pipeline_flux_kontext import *
+
+ # custom import for the transformer model
+ from model.transformer_flux import FluxTransformer2DModelwithSliderConditioning
+
+
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     is_torch_xla_available,
+     logging,
+     replace_example_docstring,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+
+ # custom pipeline that enables inference with the pretrained slider projector and the Flux-Kontext model
+ class FluxKontextSliderPipeline(FluxKontextPipeline):
+     """
+     Custom pipeline extending FluxKontextPipeline with slider conditioning.
+     Minimal changes: override __init__ to load slider_projector, and __call__ for slider-aware inference.
+     """
+
+     def __init__(
+         self,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         text_encoder_2: T5EncoderModel,
+         tokenizer_2: T5TokenizerFast,
+         transformer: FluxTransformer2DModelwithSliderConditioning,
+         image_encoder: CLIPVisionModelWithProjection = None,
+         feature_extractor: CLIPImageProcessor = None,
+         slider_projector=None,  # the slider projector model loaded with its weights
+         text_condn: bool = False,
+     ):
+         # call the parent __init__ with the base arguments of the pipeline
+         super().__init__(
+             scheduler=scheduler,
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             text_encoder_2=text_encoder_2,
+             tokenizer_2=tokenizer_2,
+             transformer=transformer,
+             image_encoder=image_encoder,
+             feature_extractor=feature_extractor,
+         )
+
+         device = self._execution_device
+         # minimal addition: attach the custom slider_projector
+         self.slider_projector = slider_projector
+
+         self.text_condn = text_condn  # whether we condition in the text space or the modulation space
+         self.slider_projector.eval()  # set to eval mode for inference
+
+     def __call__(
+         self,
+         image: Optional[PipelineImageInput] = None,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Optional[Union[str, List[str]]] = None,
+         negative_prompt: Union[str, List[str]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         true_cfg_scale: float = 1.0,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 28,
+         sigmas: Optional[List[float]] = None,
+         guidance_scale: float = 3.5,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         max_sequence_length: int = 512,
+         max_area: int = 1024**2,
+         _auto_resize: bool = True,
+         # slider values as additional inputs for the pipeline ------------- #
+         # the slider projector is already initialized in the parent call, so we can call it to obtain the slider embeddings
+         text_condn: bool = False,
+         modulation_condn: bool = False,
+         slider_value: Optional[torch.FloatTensor] = None,
+         is_clip_input: bool = False,  # whether the slider projector takes the CLIP text embedding as input for modulating
+     ):
+         # small modification to keep all values on the same device; the device is passed along with the pipeline to the model
+
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         # print("vae scale factor: {}".format(self.vae_scale_factor))
+         # print("default sample size: {}".format(self.default_sample_size))
+         # print("default sample size: height: {}, width: {}".format(height, width))
+
+         original_height, original_width = height, width
+         aspect_ratio = width / height
+         width = round((max_area * aspect_ratio) ** 0.5)
+         height = round((max_area / aspect_ratio) ** 0.5)
+
+         multiple_of = self.vae_scale_factor * 2
+         width = width // multiple_of * multiple_of
+         height = height // multiple_of * multiple_of
+         # print("after width and height quantized: height: {}, width: {}".format(height, width))
+
+         # not checking that the height and width match the predefined inference dimensions
+         # if height != original_height or width != original_width:
+         #     print("height and width are not matching the original dimensions ..")
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             height,
+             width,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+             max_sequence_length=max_sequence_length,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._current_timestep = None
+         self._interrupt = False
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = len(prompt_embeds)
+
+         device = self._execution_device
+         # print("execution device: {}".format(device))
+
+         lora_scale = (
+             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+         )
+         has_neg_prompt = negative_prompt is not None or (
+             negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+         )
+         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+         (
+             prompt_embeds,
+             pooled_prompt_embeds,
+             text_ids,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+             lora_scale=lora_scale,
+         )
+         if do_true_cfg:
+             (
+                 negative_prompt_embeds,
+                 negative_pooled_prompt_embeds,
+                 negative_text_ids,
+             ) = self.encode_prompt(
+                 prompt=negative_prompt,
+                 prompt_2=negative_prompt_2,
+                 prompt_embeds=negative_prompt_embeds,
+                 pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 lora_scale=lora_scale,
+             )
+
+         # 3. Preprocess image ---------- the older preprocessing forced images to 1024x1024, but we train with 512x512, so the output is kept at the same dimensions.
+         # if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+         #     img = image[0] if isinstance(image, list) else image
+         #     image_height, image_width = self.image_processor.get_default_height_width(img)
+         #     aspect_ratio = image_width / image_height
+         #     if _auto_resize:
+         #         # Kontext is trained on specific resolutions, using one of them is recommended
+         #         _, image_width, image_height = min(
+         #             (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+         #         )
+         #     image_width = image_width // multiple_of * multiple_of
+         #     image_height = image_height // multiple_of * multiple_of
+         #     image = self.image_processor.resize(image, image_height, image_width)
+         #     image = self.image_processor.preprocess(image, image_height, image_width)
+
+         # 3.1 Custom image preprocessing that reshapes the images to the given input dimensions.
+         # Override the height and width of the original model components, since our dataset uses a fixed image size.
+         height = original_height
+         width = original_width
+
+         image = self.image_processor.resize(image, height, width)
+         image = self.image_processor.preprocess(image, height, width)
+         # print("image shape after preprocessing: {}".format(image.shape))
+
+         # 3.2 -------------------------------------- Preparing the slider values -------------------------------------- #
+         # This is the correct way to check the device of a tensor and a model in PyTorch:
+         # print(f"slider_value device: {slider_value.device}")  # tensor device
+         # print(f"slider_projector device: {next(self.slider_projector.parameters()).device}")  # model device
+
+         # if CLIP input is enabled, compute the embeddings from both the slider values and the pooled prompt embeddings
+         if is_clip_input:
+             # TODO: this may not work with batch sizes larger than 1; validate before relying on the output.
+             # takes a vector as input, since slider_value is also a list
+             # ensure pooled_prompt_embeds is a tensor with one extra leading dimension
+             pooled_prompt_embeds_tensor = torch.tensor(pooled_prompt_embeds).unsqueeze(0).to(device)
+             slider_value = slider_value.to(device)
+
+             self.slider_projector = self.slider_projector.to(device)
+             # print("pooled prompt device: {}".format(pooled_prompt_embeds_tensor.device))
+             # print("slider value device: {}".format(slider_value.device))
+             # print("slider projector device: {}".format(next(self.slider_projector.parameters()).device))
+
+             slider_embeddings = self.slider_projector(slider_value, pooled_prompt_embeds_tensor).to(device)
+         else:
+             slider_embeddings = self.slider_projector(slider_value).to(device)
+
+         # print("slider embeddings device: {}".format(slider_embeddings.device))
+         # multiply the slider embeddings by a random value to check whether changing the slider input has any effect
+         # slider_embeddings = slider_embeddings * (np.random.rand() * 4 - 2)
+
+         # print("slider embeddings norm: {}".format(slider_embeddings.norm()))
+         # print("slider value inside the pipeline: {}".format(slider_value))
+         # print("slider embeddings: {}".format(slider_embeddings.shape))  # (1, 1, 64)
+         slider_id = torch.tensor([0, 0, 2]).reshape(1, 3).to(device)
+
+         # replicate the same slider embedding n_repeats times
+         n_repeats = 1
+         repeated_slider_token = slider_embeddings.repeat(1, n_repeats, 1)
+         repeated_slider_id = slider_id.repeat(n_repeats, 1)
+
+         # ------------------------------- concatenating the slider embeddings with the text embeddings --------------- #
+         # if we are conditioning in the text space, concatenate the slider tokens to the conditioning
+
+         if text_condn:
+             print("using text conditioning ...")
+             extended_text_ids = torch.cat([text_ids, repeated_slider_id], dim=0)
+             extended_prompt_embeds = torch.cat([prompt_embeds, repeated_slider_token], dim=1)
+         else:
+             extended_text_ids = text_ids
+             extended_prompt_embeds = prompt_embeds
+
+         if modulation_condn:
+             modulation_embeddings = repeated_slider_token
+         else:
+             modulation_embeddings = None
+
+         # print("concatenated text ids shape: {}".format(extended_text_ids.shape))  # (640, 3)
+         # print("concatenated prompt embeds shape: {}".format(extended_prompt_embeds.shape))  # (1, 640, 4096)
+
+         # print("slider id: {}".format(slider_id.shape))  # (1, 3)
+         # --------------------- the slider components defined above are used along with the other inputs in the forward pass of the model --------------------- #
+
+         # 4. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels // 4
+         latents, image_latents, latent_ids, image_ids = self.prepare_latents(
+             image,
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+         if image_ids is not None:
+             # latent_ids = torch.cat([latent_ids, image_ids], dim=0)  # dim 0 is the sequence dimension
+             # TODO: verify the shapes here when adding the slider id along with the ids for the input and target images
+             # print("original latent ids: {}".format(latent_ids.shape))
+             ## --- not using the slider id along with the visual tokens; we add it along with the text tokens --- ##
+             # latent_ids = torch.cat([latent_ids, image_ids, slider_id], dim=0)
+
+             # --- using the standard image and text latent conditioning, without adding the slider ids to the model --- ##
+             latent_ids = torch.cat([latent_ids, image_ids], dim=0)
+
+             # print("latent ids after concatenation: {}".format(latent_ids.shape))
+
+         # 5. Prepare timesteps
+         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+         image_seq_len = latents.shape[1]
+         mu = calculate_shift(
+             image_seq_len,
+             self.scheduler.config.get("base_image_seq_len", 256),
+             self.scheduler.config.get("max_image_seq_len", 4096),
+             self.scheduler.config.get("base_shift", 0.5),
+             self.scheduler.config.get("max_shift", 1.15),
+         )
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             sigmas=sigmas,
+             mu=mu,
+         )
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+         self._num_timesteps = len(timesteps)
+
+         # handle guidance
+         if self.transformer.config.guidance_embeds:
+             guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+             guidance = guidance.expand(latents.shape[0])
+         else:
+             guidance = None
+
+         # -------------- Logic for the IP adapter; we can remove this ----------------------- #
+         if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+             negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+         ):
+             negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+             negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+         elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+             negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+         ):
+             ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+             ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+         if self.joint_attention_kwargs is None:
+             self._joint_attention_kwargs = {}
+
+         image_embeds = None
+         negative_image_embeds = None
+         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+             )
+         if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+             negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                 negative_ip_adapter_image,
+                 negative_ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+             )
+
+         # 6. Denoising loop
+         # We set the index here to remove DtoH sync, helpful especially during compilation.
+         # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+         self.scheduler.set_begin_index(0)
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 self._current_timestep = t
+                 if image_embeds is not None:
+                     self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+                 # stacking the generated latent and the input image latent
+                 latent_model_input = latents
+                 if image_latents is not None:
+                     latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+                 # print("latent model shape after concatenation: {}".format(latent_model_input.shape))
+                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                 noise_pred = self.transformer(
+                     hidden_states=latent_model_input,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=extended_prompt_embeds,
+                     txt_ids=extended_text_ids,
+                     img_ids=latent_ids,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                     ## pass the modulation token when working with modulation-space conditioning
+                     modulation_embeddings=modulation_embeddings,  # defined above based on whether modulation inference is enabled
+                 )[0]
+                 noise_pred = noise_pred[:, : latents.size(1)]
+
+                 if do_true_cfg:
+                     if negative_image_embeds is not None:
+                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                     neg_noise_pred = self.transformer(
+                         hidden_states=latent_model_input,
+                         timestep=timestep / 1000,
+                         guidance=guidance,
+                         pooled_projections=negative_pooled_prompt_embeds,
+                         encoder_hidden_states=negative_prompt_embeds,
+                         txt_ids=negative_text_ids,
+                         img_ids=latent_ids,
+                         joint_attention_kwargs=self.joint_attention_kwargs,
+                         return_dict=False,
+                         modulation_embeddings=modulation_embeddings,
+                     )[0]
+                     neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_dtype = latents.dtype
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
+         self._current_timestep = None
+
+         if output_type == "latent":
+             image = latents
+         else:
+             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+             image = self.vae.decode(latents, return_dict=False)[0]
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image,)
+
+         return FluxPipelineOutput(images=image)
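
For orientation, a minimal inference sketch (not part of this commit). The base checkpoint name, out_dim=6144 (2 x 3072, so the slider token can be chunked into a scale delta and a shift delta in the modulation space), and the LoRA loading step are assumptions inferred from the files in this commit; depending on the diffusers version, the custom components may need to be assembled by hand instead of passed to from_pretrained, and dtype/device handling may need adjustment:

import torch
from diffusers.utils import load_image
from model.sliders_model import SliderProjector
from model.sliders_pipeline import FluxKontextSliderPipeline
from model.transformer_flux import FluxTransformer2DModelwithSliderConditioning

base = "black-forest-labs/FLUX.1-Kontext-dev"  # assumed base checkpoint
transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
    base, subfolder="transformer", torch_dtype=torch.bfloat16
)

projector = SliderProjector(out_dim=6144, pe_dim=2, is_clip_input=True)
# assuming slider_projector.pth stores a plain state_dict
projector.load_state_dict(torch.load("model_weights/slider_projector.pth", map_location="cpu"))

pipe = FluxKontextSliderPipeline.from_pretrained(
    base,
    transformer=transformer,
    slider_projector=projector,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("model_weights")  # pytorch_lora_weights.safetensors from this commit

out = pipe(
    image=load_image("input.png"),
    prompt="make it snow",
    height=512,
    width=512,
    slider_value=torch.tensor([[0.75]], device="cuda"),  # edit strength
    modulation_condn=True,  # condition in the modulation space
    is_clip_input=True,
)
out.images[0].save("edited.png")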
model/transformer_flux.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
24
+ from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from diffusers.utils.import_utils import is_torch_npu_available
26
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
27
+ from diffusers.models.attention import FeedForward
28
+ from diffusers.models.attention_processor import (
29
+ Attention,
30
+ AttentionProcessor,
31
+ FluxAttnProcessor2_0,
32
+ FluxAttnProcessor2_0_NPU,
33
+ FusedFluxAttnProcessor2_0,
34
+ )
35
+ from diffusers.models.cache_utils import CacheMixin
36
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
37
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
38
+ from diffusers.models.modeling_utils import ModelMixin
39
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ @maybe_allow_in_graph
46
+ class FluxSingleTransformerBlock(nn.Module):
47
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
48
+ super().__init__()
49
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
50
+
51
+ self.norm = AdaLayerNormZeroSingle(dim)
52
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
53
+ self.act_mlp = nn.GELU(approximate="tanh")
54
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
55
+
56
+ if is_torch_npu_available():
57
+ deprecation_message = (
58
+ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
59
+ "should be set explicitly using the `set_attn_processor` method."
60
+ )
61
+ deprecate("npu_processor", "0.34.0", deprecation_message)
62
+ processor = FluxAttnProcessor2_0_NPU()
63
+ else:
64
+ processor = FluxAttnProcessor2_0()
65
+
66
+ self.attn = Attention(
67
+ query_dim=dim,
68
+ cross_attention_dim=None,
69
+ dim_head=attention_head_dim,
70
+ heads=num_attention_heads,
71
+ out_dim=dim,
72
+ bias=True,
73
+ processor=processor,
74
+ qk_norm="rms_norm",
75
+ eps=1e-6,
76
+ pre_only=True,
77
+ )
78
+
79
+ def forward(
80
+ self,
81
+ hidden_states: torch.Tensor,
82
+ temb: torch.Tensor,
83
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
84
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
85
+ ) -> torch.Tensor:
86
+ residual = hidden_states
87
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
88
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
89
+ joint_attention_kwargs = joint_attention_kwargs or {}
90
+ attn_output = self.attn(
91
+ hidden_states=norm_hidden_states,
92
+ image_rotary_emb=image_rotary_emb,
93
+ **joint_attention_kwargs,
94
+ )
95
+
96
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
97
+ gate = gate.unsqueeze(1)
98
+ hidden_states = gate * self.proj_out(hidden_states)
99
+ hidden_states = residual + hidden_states
100
+ if hidden_states.dtype == torch.float16:
101
+ hidden_states = hidden_states.clip(-65504, 65504)
102
+
103
+ return hidden_states
104
+
105
+
106
+ @maybe_allow_in_graph
107
+ class FluxTransformerBlock(nn.Module):
108
+ def __init__(
109
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
110
+ ):
111
+ super().__init__()
112
+
113
+ self.norm1 = AdaLayerNormZero(dim)
114
+ self.norm1_context = AdaLayerNormZero(dim)
115
+
116
+ self.attn = Attention(
117
+ query_dim=dim,
118
+ cross_attention_dim=None,
119
+ added_kv_proj_dim=dim,
120
+ dim_head=attention_head_dim,
121
+ heads=num_attention_heads,
122
+ out_dim=dim,
123
+ context_pre_only=False,
124
+ bias=True,
125
+ processor=FluxAttnProcessor2_0(),
126
+ qk_norm=qk_norm,
127
+ eps=eps,
128
+ )
129
+
130
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
131
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
132
+
133
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
134
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
135
+
136
+ def forward(
137
+ self,
138
+ hidden_states: torch.Tensor,
139
+ encoder_hidden_states: torch.Tensor,
140
+ temb: torch.Tensor,
141
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
142
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
143
+ ## ---- adding the modulation conditioning vector for controlling the strength ---- ##
144
+ modulation_condn: Optional[torch.Tensor] = None,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+
147
+ # add logic here for conditioning
148
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
149
+
150
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
151
+ encoder_hidden_states, emb=temb
152
+ )
153
+
154
+ # If modulation conditioning is passed, then we will use that to adjust the scale and shift parameters otherwise we will proceed with regular modulation function
155
+ if modulation_condn is not None:
156
+ modulation_condn = modulation_condn.squeeze(1)
157
+
158
+ # chunking the modulation space here
159
+ modulation_scale, modulation_shift = modulation_condn.chunk(2, dim=1) # dividing the output into two parts, one for scale and another one for shift.
160
+ # print("modulation condn shape: {}".format(modulation_condn.shape)) # [1, out_dim]
161
+
162
+ # adding a delta shift to the shift modulation vector
163
+ c_shift_mlp = c_shift_mlp + modulation_shift
164
+ # adding a delta scale to the scale modulation vector
165
+ c_scale_mlp = c_scale_mlp + modulation_scale
166
+
167
+ joint_attention_kwargs = joint_attention_kwargs or {}
168
+ # Attention.
169
+ attention_outputs = self.attn(
170
+ hidden_states=norm_hidden_states,
171
+ encoder_hidden_states=norm_encoder_hidden_states,
172
+ image_rotary_emb=image_rotary_emb,
173
+ **joint_attention_kwargs,
174
+ )
175
+
176
+ if len(attention_outputs) == 2:
177
+ attn_output, context_attn_output = attention_outputs
178
+ elif len(attention_outputs) == 3:
179
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
180
+
181
+ # Process attention outputs for the `hidden_states`.
182
+ attn_output = gate_msa.unsqueeze(1) * attn_output
183
+ hidden_states = hidden_states + attn_output
184
+
185
+ norm_hidden_states = self.norm2(hidden_states)
186
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
187
+
188
+ ff_output = self.ff(norm_hidden_states)
189
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
190
+
191
+ hidden_states = hidden_states + ff_output
192
+ if len(attention_outputs) == 3:
193
+ hidden_states = hidden_states + ip_attn_output
194
+
195
+ # Process attention outputs for the `encoder_hidden_states`.
196
+
197
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
198
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
199
+
200
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
201
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
202
+
203
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
204
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
205
+ if encoder_hidden_states.dtype == torch.float16:
206
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
207
+
208
+ return encoder_hidden_states, hidden_states
209
+
210
+
211
+ class FluxTransformer2DModelwithSliderConditioning(
212
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
213
+ ):
214
+ """
215
+ The Transformer model introduced in Flux.
216
+
217
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
218
+
219
+ Args:
220
+ patch_size (`int`, defaults to `1`):
221
+ Patch size to turn the input data into small patches.
222
+ in_channels (`int`, defaults to `64`):
223
+ The number of channels in the input.
224
+ out_channels (`int`, *optional*, defaults to `None`):
225
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
226
+ num_layers (`int`, defaults to `19`):
227
+ The number of layers of dual stream DiT blocks to use.
228
+ num_single_layers (`int`, defaults to `38`):
229
+ The number of layers of single stream DiT blocks to use.
230
+ attention_head_dim (`int`, defaults to `128`):
231
+ The number of dimensions to use for each attention head.
232
+ num_attention_heads (`int`, defaults to `24`):
233
+ The number of attention heads to use.
234
+ joint_attention_dim (`int`, defaults to `4096`):
235
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
236
+ `encoder_hidden_states`).
237
+ pooled_projection_dim (`int`, defaults to `768`):
238
+ The number of dimensions to use for the pooled projection.
239
+ guidance_embeds (`bool`, defaults to `False`):
240
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
241
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
242
+ The dimensions to use for the rotary positional embeddings.
243
+ """
244
+
245
+ _supports_gradient_checkpointing = True
246
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
247
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
248
+ _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
249
+
250
+ @register_to_config
251
+ def __init__(
252
+ self,
253
+ patch_size: int = 1,
254
+ in_channels: int = 64,
255
+ out_channels: Optional[int] = None,
256
+ num_layers: int = 19,
257
+ num_single_layers: int = 38,
258
+ attention_head_dim: int = 128,
259
+ num_attention_heads: int = 24,
260
+ joint_attention_dim: int = 4096,
261
+ pooled_projection_dim: int = 768,
262
+ guidance_embeds: bool = False,
263
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
264
+ ):
265
+ super().__init__()
266
+ self.out_channels = out_channels or in_channels
267
+ self.inner_dim = num_attention_heads * attention_head_dim
268
+
269
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
270
+
271
+ text_time_guidance_cls = (
272
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
273
+ )
274
+ self.time_text_embed = text_time_guidance_cls(
275
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
276
+ )
277
+
278
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
279
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
280
+
281
+ self.transformer_blocks = nn.ModuleList(
282
+ [ # we will add conditioning logic in this block for training with modulation space
283
+ FluxTransformerBlock(
284
+ dim=self.inner_dim,
285
+ num_attention_heads=num_attention_heads,
286
+ attention_head_dim=attention_head_dim,
287
+ )
288
+ for _ in range(num_layers)
289
+ ]
290
+ )
291
+
292
+ self.single_transformer_blocks = nn.ModuleList(
293
+ [
294
+ FluxSingleTransformerBlock(
295
+ dim=self.inner_dim,
296
+ num_attention_heads=num_attention_heads,
297
+ attention_head_dim=attention_head_dim,
298
+ )
299
+ for _ in range(num_single_layers)
300
+ ]
301
+ )
302
+
303
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
304
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
305
+
306
+ self.gradient_checkpointing = False
307
+
308
+ @property
309
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
310
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
311
+ r"""
312
+ Returns:
313
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
314
+ indexed by its weight name.
315
+ """
316
+ # set recursively
317
+ processors = {}
318
+
319
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
320
+ if hasattr(module, "get_processor"):
321
+ processors[f"{name}.processor"] = module.get_processor()
322
+
323
+ for sub_name, child in module.named_children():
324
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
325
+
326
+ return processors
327
+
328
+ for name, module in self.named_children():
329
+ fn_recursive_add_processors(name, module, processors)
330
+
331
+ return processors
332
+
333
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
334
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
335
+ r"""
336
+ Sets the attention processor to use to compute attention.
337
+
338
+ Parameters:
339
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
340
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
341
+ for **all** `Attention` layers.
342
+
343
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
344
+ processor. This is strongly recommended when setting trainable attention processors.
345
+
346
+ """
347
+ count = len(self.attn_processors.keys())
348
+
349
+ if isinstance(processor, dict) and len(processor) != count:
350
+ raise ValueError(
351
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
352
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
353
+ )
354
+
355
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
356
+ if hasattr(module, "set_processor"):
357
+ if not isinstance(processor, dict):
358
+ module.set_processor(processor)
359
+ else:
360
+ module.set_processor(processor.pop(f"{name}.processor"))
361
+
362
+ for sub_name, child in module.named_children():
363
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
364
+
365
+ for name, module in self.named_children():
366
+ fn_recursive_attn_processor(name, module, processor)
367
+
368
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
369
+ def fuse_qkv_projections(self):
370
+ """
371
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
372
+ are fused. For cross-attention modules, key and value projection matrices are fused.
373
+
374
+ <Tip warning={true}>
375
+
376
+ This API is 🧪 experimental.
377
+
378
+ </Tip>
379
+ """
380
+ self.original_attn_processors = None
381
+
382
+ for _, attn_processor in self.attn_processors.items():
383
+ if "Added" in str(attn_processor.__class__.__name__):
384
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
385
+
386
+ self.original_attn_processors = self.attn_processors
387
+
388
+ for module in self.modules():
389
+ if isinstance(module, Attention):
390
+ module.fuse_projections(fuse=True)
391
+
392
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
393
+
394
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
395
+ def unfuse_qkv_projections(self):
396
+ """Disables the fused QKV projection if enabled.
397
+
398
+ <Tip warning={true}>
399
+
400
+ This API is 🧪 experimental.
401
+
402
+ </Tip>
403
+
404
+ """
405
+ if self.original_attn_processors is not None:
406
+ self.set_attn_processor(self.original_attn_processors)
407
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
+        return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
+        # Modulation conditioning: an embedding that is used to modulate the features of the diffusion model.
+        modulation_embeddings: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        """
+        The [`FluxTransformer2DModelwithSliderConditioning`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
+                Embeddings projected from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
+                Used to indicate the denoising step.
+            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
+                A list of tensors that, if specified, are added to the residuals of the transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            modulation_embeddings (`torch.Tensor`, *optional*):
+                Slider conditioning embedding used to modulate the features of the transformer blocks.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+
+        # Optional debug trace for the modulation-space conditioning:
+        # if modulation_embeddings is not None:
+        #     print("working with modulation space conditioning ...")
+        #     print("modulation condn in main transformer call: {}".format(modulation_embeddings.shape))
+
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        hidden_states = self.x_embedder(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+
+        temb = (
+            self.time_text_embed(timestep, pooled_projections)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, pooled_projections)
+        )
+
+        # print("temb shape: {}".format(temb.shape))  # [1, 3072]
+
+        # ------------- Alternative (disabled): inject the predicted embedding at the root of the modulation branch ------------- #
+        # modulation_embeddings = modulation_embeddings.squeeze(1)
+        # scale_factor = 10
+        # temb = temb + modulation_embeddings * scale_factor
+        # --------------------------------------------------------------------------------------------------------------------- #
+
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
+                "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
+                "Please remove the batch dimension and pass it as a 2d torch.Tensor."
+            )
+            img_ids = img_ids[0]
+
+        # Text and image token ids are concatenated so both streams share one rotary position-embedding table.
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+        image_rotary_emb = self.pos_embed(ids)
+
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+        # Iterate over the double-stream transformer blocks, which process the text and image streams
+        # separately before they are combined later.
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                # TODO: validate this path; gradient checkpointing with the extra modulation argument is untested
+                # and may not work for us.
+                # # Copied from the syncd codebase for gradient checkpointing with new arguments:
+                # def create_custom_forward(module, return_dict=None):
+                #     def custom_forward(*inputs):
+                #         if return_dict is not None:
+                #             return module(*inputs, return_dict=return_dict)
+                #         else:
+                #             return module(*inputs)
+                #
+                #     return custom_forward
+
+                # new_kwargs = {
+                #     "modulation_condn": modulation_condn,
+                # }
+
+                # Gradient checkpointing trades compute for memory: intermediate activations are
+                # recomputed in the backward pass instead of being stored.
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+
+            else:
+                # Pass the modulation conditioning vector into each transformer block so it can adjust the features.
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,  # inside the block, a delta derived from the new conditioning is added to temb
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                    # The modulation conditioning vector that controls the editing strength; set to None instead
+                    # when the conditioning is injected at the temb root (see the disabled block above).
+                    modulation_condn=modulation_embeddings,
+                )
+
+            # ControlNet residual: map the (possibly fewer) ControlNet samples onto the transformer blocks.
+            if controlnet_block_samples is not None:
+                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                # e.g. 19 blocks and 4 samples give interval_control = ceil(19 / 4) = 5,
+                # so blocks 0-4 reuse sample 0, blocks 5-9 sample 1, and so on.
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # Single-stream transformer blocks: text and image tokens are processed jointly in a single pass.
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # ControlNet residual (applied to the image tokens only; the text tokens occupy the first
+            # encoder_hidden_states.shape[1] positions of the concatenated sequence).
+            if controlnet_single_block_samples is not None:
+                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                    + controlnet_single_block_samples[index_block // interval_control]
+                )
+
+        # Drop the text tokens and keep only the image tokens for the final projection.
+        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+        hidden_states = self.norm_out(hidden_states, temb)
+        output = self.proj_out(hidden_states)
+
+        if USE_PEFT_BACKEND:
+            # Remove `lora_scale` from each PEFT layer.
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
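
A minimal smoke-test sketch of the forward call (editorial, not part of the commit); the tensor shapes follow the FLUX.1 defaults, and the modulation-embedding width of 3072 is an assumption taken from the `temb` shape comment above:

    import torch

    B, img_seq, txt_seq = 1, 1024, 512
    out = transformer(
        hidden_states=torch.randn(B, img_seq, 64),            # packed latent tokens
        encoder_hidden_states=torch.randn(B, txt_seq, 4096),  # T5 text embeddings
        pooled_projections=torch.randn(B, 768),               # pooled CLIP embedding
        timestep=torch.tensor([1.0]),
        img_ids=torch.zeros(img_seq, 3),                      # 2-D ids, no batch dimension
        txt_ids=torch.zeros(txt_seq, 3),
        guidance=torch.full((B,), 3.5),                       # only for guidance-distilled checkpoints
        modulation_embeddings=torch.randn(B, 3072),           # assumed width; matches temb
    ).sample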
model_weights/.DS_Store ADDED
Binary file (6.15 kB)
model_weights/pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5746f688412bbc53ca3eb4042c35021d2dc055f4db8bb36bc8f2368e7ec9ecb1
+ size 22505648
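
Each weight file below is stored the same way: before `git lfs pull`, the on-disk file holds just the three pointer lines above. An editorial sketch of recovering the expected payload hash and size from such a pointer (path assumed relative to the repo root):

    # Editor's sketch: parse a Git LFS pointer file into its key/value fields.
    with open("model_weights/pytorch_lora_weights.safetensors") as f:
        pointer = dict(line.strip().split(" ", 1) for line in f)
    print(pointer["oid"], pointer["size"])  # sha256:5746f6..., 22505648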
model_weights/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e391836cc75c0818205e6a36b62a71c6257a78138fb911e2fc7368f419fa870
+ size 16513
model_weights/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:48f9175e41bb943ae6b74237893b92279cffc2863e7dc98c14140fdb00cc418f
+ size 1465
model_weights/slider_projector.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d7ad270dd8ddf2d120ee493171cc94cab3edddfb7d320a53906c8792a2573b8
+ size 245425109
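
A minimal loading sketch for these checkpoints (editorial; it assumes the LoRA targets a FLUX.1 pipeline and that the projector checkpoint is a plain PyTorch file, neither of which is stated in this commit):

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.load_lora_weights("model_weights", weight_name="pytorch_lora_weights.safetensors")
    slider_projector = torch.load("model_weights/slider_projector.pth", map_location="cpu")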
sample_images/precomputed/aesthetic_model2_vangogh.png ADDED
Git LFS Details
  • SHA256: 40a5c7dcd1eeb2cdcf4a59ea131565746efda7152b6c8965af46a8feff982fcd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_0.png ADDED
Git LFS Details
  • SHA256: ccee067e0c4228689976afc5c5dafcf67656b1556199f31a5c9a4b6f0c255dca
  • Pointer size: 132 Bytes
  • Size of remote file: 1.29 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_1.png ADDED
Git LFS Details
  • SHA256: 70ef8ed0c006e79d61e1f97a1425baf3c4c49437585f5a71a795dd640039b84f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_10.png ADDED
Git LFS Details
  • SHA256: 7c09368b75267fa394ac8ec07ae78caded81f2b8d29b070b1f775d0b6ccdae4a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_2.png ADDED
Git LFS Details
  • SHA256: 1406e6561b6636db239db19ae099417d370c7b972edcf5296418f9f083e60a2a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_3.png ADDED
Git LFS Details
  • SHA256: 0a0c88a983e7fc785c6a5ed956492ffbdbaeca89bd9eed392c7892fb6639cf7c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_4.png ADDED
Git LFS Details
  • SHA256: aed877ac7f9679737ed92b089878b388bd431db83490de48e38b670438616233
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_5.png ADDED
Git LFS Details
  • SHA256: 07f35c6defb0cecffa02ff2050f814128683110f1039fca7df8f4439fcbcaf37
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_6.png ADDED
Git LFS Details
  • SHA256: 75e504622e5ccaa8b955814028659520baa33c1914d6f1d6fba798dd662a8019
  • Pointer size: 132 Bytes
  • Size of remote file: 1.71 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_7.png ADDED
Git LFS Details
  • SHA256: b8a31ebca13426a52b0c4741624b831a5e0ddc876b43ddc1548dd3d9cf12998e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.85 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_8.png ADDED
Git LFS Details
  • SHA256: 50f003883e6ae9a33803ec496dc61ce4804fa54b413021f43ef0c73e74003cd7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.95 MB
sample_images/precomputed/aesthetic_model2_vangogh/image_9.png ADDED
Git LFS Details
  • SHA256: 06430d50fe132b516a1b06b1ce66091d889c2f1cb9c287e5873efa45090fd46c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.04 MB
sample_images/precomputed/enfield3_winter_snow.png ADDED
Git LFS Details
  • SHA256: 7da9bb49158c845e6f64175ab0c8be765566110f5db94ed06000366c6070f5cd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
sample_images/precomputed/enfield3_winter_snow/image_0.png ADDED
Git LFS Details
  • SHA256: 697b1ad79bfc4706ee49bb45bb4aa4931819f696c2386f60cbfdec164d2b53b6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
sample_images/precomputed/enfield3_winter_snow/image_1.png ADDED
Git LFS Details
  • SHA256: cffee97f44ea7b96c662645941737ab267e2f7bb157632a2b6ab356756a1777e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
sample_images/precomputed/enfield3_winter_snow/image_10.png ADDED
Git LFS Details
  • SHA256: a6ea9e053852ad2ce9e37918912c6a30495644ee1e2a14ef254b6f9e801cb869
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
sample_images/precomputed/enfield3_winter_snow/image_2.png ADDED
Git LFS Details
  • SHA256: b8e8347be2460ac636c920e39c23c487d149491aebafae37d07d6a85d1243d3a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
sample_images/precomputed/enfield3_winter_snow/image_3.png ADDED
Git LFS Details
  • SHA256: 868d5fe662b030e8735d31066de70cc4af44b6bb1ff6ed8abaaa23f82c16dca4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
sample_images/precomputed/enfield3_winter_snow/image_4.png ADDED
Git LFS Details
  • SHA256: da9b5643b1a8b810556108af1aff9bee2f55b076c844fbe9054a51bd213c52b8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
sample_images/precomputed/enfield3_winter_snow/image_5.png ADDED
Git LFS Details
  • SHA256: a3088b90c06640ca9b969eade9d189b5aff41afef8689d4eacaa20e44c3f8223
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
sample_images/precomputed/enfield3_winter_snow/image_6.png ADDED
Git LFS Details
  • SHA256: c038f494d112fade0382addf8182efcfd6e484ab9e00c1b9979a053d1caca5ac
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
sample_images/precomputed/enfield3_winter_snow/image_7.png ADDED
Git LFS Details
  • SHA256: 55c5c842a64e5556cea8f8006f0d206d969e742ccc89eb1cb41004292dba4798
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
sample_images/precomputed/enfield3_winter_snow/image_8.png ADDED
Git LFS Details
  • SHA256: 761372cb2ac03280c3c928c2bef56a65ebe029eafdd08a8e1b682c566dcce954
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
sample_images/precomputed/enfield3_winter_snow/image_9.png ADDED
Git LFS Details
  • SHA256: cadb36626fa1354c56218a2cff54a342701388d8378a1d507bfeb892a4b25181
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
sample_images/precomputed/jackson_fluffy.png ADDED
Git LFS Details
  • SHA256: b4090275dfde46631b7127a84ce1eebc2a62ff17b2be09850984f47756d268cb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.14 MB
sample_images/precomputed/jackson_fluffy/image_0.png ADDED
Git LFS Details
  • SHA256: 5e902034f6af4a9abf22c4aff5d800be827641466c9c6b409c173cf690ac8254
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
sample_images/precomputed/jackson_fluffy/image_1.png ADDED
Git LFS Details
  • SHA256: 6799019b8334d5045a476ecae2b62f0259a92ca667ef48d5a25c85fc68b79fd4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
sample_images/precomputed/jackson_fluffy/image_10.png ADDED
Git LFS Details
  • SHA256: 91677e9150c384413ce86e1881cafc5bf7a0bdc1958bc7e86eed236e12ce5c56
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
sample_images/precomputed/jackson_fluffy/image_11.png ADDED
Git LFS Details
  • SHA256: 87de94102b04ad5be63afb7fce11c32e65e752c9f09813fdc4f29ae6abcb271a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
sample_images/precomputed/jackson_fluffy/image_2.png ADDED
Git LFS Details
  • SHA256: 0f6208e7ba3d460ae81f322133f16b3a91a6af0f3e84f5f4892dd68d94f099c1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
sample_images/precomputed/jackson_fluffy/image_3.png ADDED
Git LFS Details
  • SHA256: 57476e1327bed4fa83fba4d7ff0ec71e18ab30e51c1e57f7e410dd025b3f2793
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
sample_images/precomputed/jackson_fluffy/image_4.png ADDED
Git LFS Details
  • SHA256: f86db628089393fb4c8822a7b49a6fe9dcd357d8df7159a6939fb59aac9d8184
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
sample_images/precomputed/jackson_fluffy/image_5.png ADDED
Git LFS Details
  • SHA256: 4ae7fc49b785c36d92f9ac91bd217d020c67ef1bd6712e048cacdcab04c4cd0f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
sample_images/precomputed/jackson_fluffy/image_6.png ADDED
Git LFS Details
  • SHA256: a02076072483204e12b4596df807c5ba212e65f0b05bfb0aa57c4a9dc203ae2d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
sample_images/precomputed/jackson_fluffy/image_7.png ADDED
Git LFS Details
  • SHA256: a854d126f2a9411561dac0376530980cffdd094da3408742985026d0e8a0b8ab
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB