Anonymous20250508 committed on
Commit 924a45b · verified · 1 Parent(s): b081447

Upload 84 files

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +44 -0
  2. DST/config/deepspeed/zero2_config.json +15 -0
  3. DST/config/deepspeed/zero3_config.json +33 -0
  4. DST/datasets/dreambench_style.json +9 -0
  5. DST/dst/dataset/dst.py +79 -0
  6. DST/dst/flux/math.py +31 -0
  7. DST/dst/flux/model.py +208 -0
  8. DST/dst/flux/modules/autoencoder.py +312 -0
  9. DST/dst/flux/modules/conditioner.py +39 -0
  10. DST/dst/flux/modules/layers.py +421 -0
  11. DST/dst/flux/pipeline.py +266 -0
  12. DST/dst/flux/sampling.py +243 -0
  13. DST/dst/flux/util.py +404 -0
  14. DST/dst/utils/convert_yaml_to_args_file.py +21 -0
  15. DST/inference.py +104 -0
  16. DST/output/tower@American Comic_Architecture_Church or mosque_7f69557d-751b-4c7c-9495-5156b2513989_42.jpg +3 -0
  17. DST/output/tower@American Comic_Object_Backpack or bag_938b06c4-92bb-4535-a2cf-6112318d0c0d_42.jpg +3 -0
  18. DST/output/tower@Anime_04c5405f-fcaa-4065-899e-49149e2835e7.jpg +3 -0
  19. DST/output/tower@Anime_Animal_Dog_3d4f8063-1ed9-4601-8c16-6f9f2e77c81b_42.jpg +3 -0
  20. DST/output/tower@Flat Design_Scene_Beach or coast_31be5a9d-5c07-4c60-be0c-6b0034e8244b_42.jpg +3 -0
  21. DST/output/tower@Flat Design_b5f391f3-c7f4-4c21-8f35-a6f33df8eae5.jpg +3 -0
  22. DST/output/tower@Ghibli_19ed1a7f-d7ef-49f6-9f7d-9d5961c385cc.jpg +3 -0
  23. DST/output/tower@Graffiti_Scene_Forest scene_64c1b63c-d094-4f50-bf71-2ff3d0035dd8_42.jpg +3 -0
  24. DST/output/tower@Line Art_3f6a80ef-5fa8-492a-a8b6-1b515e3ebd97.jpg +3 -0
  25. DST/output/tower@Linocut_15502303-0459-4a02-be7e-b5b93ba69c1e.jpg +3 -0
  26. DST/output/tower@Neon_Scene_Beach or coast_c42c6871-6268-41c9-8ec7-a36c999b5acb_42.jpg +3 -0
  27. DST/output/tower@Pixel Art_8b869e57-7345-4f78-8d8b-07a2def7979c.jpg +3 -0
  28. DST/output/tower@Watercolor_e15d75e6-796f-4289-ae2e-a0b04ba1a5ea.jpg +3 -0
  29. DST/readme.md +35 -0
  30. DST/requirements.txt +16 -0
  31. DST/run.sh +10 -0
  32. DST/save/1024_modernart/dit_lora.safetensors +3 -0
  33. DST/save/1024_nga/dit_lora.safetensors +3 -0
  34. DST/test.sh +7 -0
  35. DST/test/cnt/tower.jpg +3 -0
  36. DST/test/cnt_nga/0field.jpeg +0 -0
  37. DST/test/cnt_nga/0rahul-chakraborty-9Wg7qAhGmnU-unsplash.jpg +3 -0
  38. DST/test/cnt_nga/0trip.jpg +3 -0
  39. DST/test/cnt_nga/1mio-ito-DaGIjXNl5oA-unsplash.jpg +3 -0
  40. DST/test/sty/American Comic_Architecture_Church or mosque_7f69557d-751b-4c7c-9495-5156b2513989_42.png +3 -0
  41. DST/test/sty/American Comic_Object_Backpack or bag_938b06c4-92bb-4535-a2cf-6112318d0c0d_42.png +3 -0
  42. DST/test/sty/Anime_04c5405f-fcaa-4065-899e-49149e2835e7.png +3 -0
  43. DST/test/sty/Anime_Animal_Dog_3d4f8063-1ed9-4601-8c16-6f9f2e77c81b_42.png +3 -0
  44. DST/test/sty/Flat Design_Scene_Beach or coast_31be5a9d-5c07-4c60-be0c-6b0034e8244b_42.png +3 -0
  45. DST/test/sty/Flat Design_b5f391f3-c7f4-4c21-8f35-a6f33df8eae5.png +3 -0
  46. DST/test/sty/Ghibli_19ed1a7f-d7ef-49f6-9f7d-9d5961c385cc.png +3 -0
  47. DST/test/sty/Graffiti_Scene_Forest scene_64c1b63c-d094-4f50-bf71-2ff3d0035dd8_42.png +3 -0
  48. DST/test/sty/Line Art_3f6a80ef-5fa8-492a-a8b6-1b515e3ebd97.png +3 -0
  49. DST/test/sty/Linocut_15502303-0459-4a02-be7e-b5b93ba69c1e.png +3 -0
  50. DST/test/sty/Neon_Scene_Beach or coast_c42c6871-6268-41c9-8ec7-a36c999b5acb_42.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,47 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@American[[:space:]]Comic_Architecture_Church[[:space:]]or[[:space:]]mosque_7f69557d-751b-4c7c-9495-5156b2513989_42.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@American[[:space:]]Comic_Object_Backpack[[:space:]]or[[:space:]]bag_938b06c4-92bb-4535-a2cf-6112318d0c0d_42.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Anime_04c5405f-fcaa-4065-899e-49149e2835e7.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Anime_Animal_Dog_3d4f8063-1ed9-4601-8c16-6f9f2e77c81b_42.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Flat[[:space:]]Design_b5f391f3-c7f4-4c21-8f35-a6f33df8eae5.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Flat[[:space:]]Design_Scene_Beach[[:space:]]or[[:space:]]coast_31be5a9d-5c07-4c60-be0c-6b0034e8244b_42.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Ghibli_19ed1a7f-d7ef-49f6-9f7d-9d5961c385cc.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Graffiti_Scene_Forest[[:space:]]scene_64c1b63c-d094-4f50-bf71-2ff3d0035dd8_42.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Line[[:space:]]Art_3f6a80ef-5fa8-492a-a8b6-1b515e3ebd97.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Linocut_15502303-0459-4a02-be7e-b5b93ba69c1e.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Neon_Scene_Beach[[:space:]]or[[:space:]]coast_c42c6871-6268-41c9-8ec7-a36c999b5acb_42.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Pixel[[:space:]]Art_8b869e57-7345-4f78-8d8b-07a2def7979c.jpg filter=lfs diff=lfs merge=lfs -text
+DST/output/tower@Watercolor_e15d75e6-796f-4289-ae2e-a0b04ba1a5ea.jpg filter=lfs diff=lfs merge=lfs -text
+DST/test/cnt_nga/0rahul-chakraborty-9Wg7qAhGmnU-unsplash.jpg filter=lfs diff=lfs merge=lfs -text
+DST/test/cnt_nga/0trip.jpg filter=lfs diff=lfs merge=lfs -text
+DST/test/cnt_nga/1mio-ito-DaGIjXNl5oA-unsplash.jpg filter=lfs diff=lfs merge=lfs -text
+DST/test/cnt/tower.jpg filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/1.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/11.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/2.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/3.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/4.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/5.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/6.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/7.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/8.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty_nga/9.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/American[[:space:]]Comic_Architecture_Church[[:space:]]or[[:space:]]mosque_7f69557d-751b-4c7c-9495-5156b2513989_42.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/American[[:space:]]Comic_Object_Backpack[[:space:]]or[[:space:]]bag_938b06c4-92bb-4535-a2cf-6112318d0c0d_42.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Anime_04c5405f-fcaa-4065-899e-49149e2835e7.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Anime_Animal_Dog_3d4f8063-1ed9-4601-8c16-6f9f2e77c81b_42.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Flat[[:space:]]Design_b5f391f3-c7f4-4c21-8f35-a6f33df8eae5.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Flat[[:space:]]Design_Scene_Beach[[:space:]]or[[:space:]]coast_31be5a9d-5c07-4c60-be0c-6b0034e8244b_42.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Ghibli_19ed1a7f-d7ef-49f6-9f7d-9d5961c385cc.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Graffiti_Scene_Forest[[:space:]]scene_64c1b63c-d094-4f50-bf71-2ff3d0035dd8_42.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Line[[:space:]]Art_3f6a80ef-5fa8-492a-a8b6-1b515e3ebd97.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Linocut_15502303-0459-4a02-be7e-b5b93ba69c1e.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Neon_Scene_Beach[[:space:]]or[[:space:]]coast_c42c6871-6268-41c9-8ec7-a36c999b5acb_42.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Pixel[[:space:]]Art_8b869e57-7345-4f78-8d8b-07a2def7979c.png filter=lfs diff=lfs merge=lfs -text
+DST/test/sty/Watercolor_e15d75e6-796f-4289-ae2e-a0b04ba1a5ea.png filter=lfs diff=lfs merge=lfs -text
+DST/train_json/flux2_train_data.json filter=lfs diff=lfs merge=lfs -text
+DST/train_json/merged_all.json filter=lfs diff=lfs merge=lfs -text
+DST/train_json/nga_sft_train.json filter=lfs diff=lfs merge=lfs -text
+DST/train_json/train_impressionism_aug.json filter=lfs diff=lfs merge=lfs -text
DST/config/deepspeed/zero2_config.json ADDED
@@ -0,0 +1,15 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none"
        },
        "contiguous_gradients": true,
        "overlap_comm": true
    },
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": "auto"
}
DST/config/deepspeed/zero3_config.json ADDED
@@ -0,0 +1,33 @@
{
    "bf16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 16777216,
        "stage3_prefetch_bucket_size": 15099494,
        "stage3_param_persistence_threshold": 40960,
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": 1
}
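As a usage note (not shown in this commit, so treat it as a sketch): these configs are normally resolved and passed to DeepSpeed by the training launcher. A minimal hand-rolled version, assuming `model` is the Flux transformer being fine-tuned; the `"auto"` fields are a Hugging Face Trainer/Accelerate convention, so plain `deepspeed.initialize` needs concrete values in their place:

# Hypothetical wiring sketch for the ZeRO configs above.
import json

import deepspeed

with open("DST/config/deepspeed/zero2_config.json") as f:
    ds_config = json.load(f)
ds_config["gradient_accumulation_steps"] = 1   # resolve "auto" by hand
ds_config["bf16"]["enabled"] = True            # resolve "auto" by hand

# `model` is assumed to already exist (the Flux transformer being trained)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config=ds_config,
)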
DST/datasets/dreambench_style.json ADDED
@@ -0,0 +1,9 @@
[
    {
        "prompt": "",
        "image_paths": [
            "./test/sty/Neon_Scene_Beach or coast_c42c6871-6268-41c9-8ec7-a36c999b5acb_42.png",
            "./test/cnt/bridge.jpg"
        ],
        "image_tgt_path": "./test/cnt/bridge.jpg"
    }
]
DST/dst/dataset/dst.py ADDED
@@ -0,0 +1,79 @@
import json
import os

import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor, Resize


def bucket_images(images: list[torch.Tensor], resolution: int = 512):
    # All images were already resized to a fixed size by the transform, so
    # aspect-ratio bucketing is a no-op here: simply stack into a batch tensor.
    images = torch.stack(images, dim=0)
    return images


class FluxPairedDatasetV2(Dataset):
    def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
        super().__init__()
        self.json_file = json_file
        self.resolution = resolution
        self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
        self.image_root = os.path.dirname(json_file)

        with open(self.json_file, "rt") as f:
            self.data_dicts = json.load(f)

        self.transform = Compose([
            Resize((1024, 1024)),  # resize first, then normalize to [-1, 1]
            ToTensor(),
            Normalize([0.5], [0.5]),
        ])

    def __getitem__(self, idx):
        data_dict = self.data_dicts[idx]
        image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
        txt = data_dict["prompt"]
        image_tgt_path = data_dict.get("image_tgt_path", None)

        ref_imgs = [
            Image.open(os.path.join(self.image_root, path)).convert("RGB")
            for path in image_paths
        ]
        ref_imgs = [self.transform(img) for img in ref_imgs]
        img = None
        if image_tgt_path is not None:
            img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
            img = self.transform(img)

        return {
            "img": img,
            "txt": txt,
            "ref_imgs": ref_imgs,
        }

    def __len__(self):
        return len(self.data_dicts)

    def collate_fn(self, batch):
        img = [data["img"] for data in batch]
        txt = [data["txt"] for data in batch]
        ref_imgs = [data["ref_imgs"] for data in batch]
        # every sample in a batch must provide the same number of reference images
        assert all(len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs)))

        n_ref = len(ref_imgs[0])

        img = bucket_images(img, self.resolution)
        ref_imgs_new = []
        for i in range(n_ref):
            ref_imgs_i = [refs[i] for refs in ref_imgs]
            ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
            ref_imgs_new.append(ref_imgs_i)

        return {
            "txt": txt,
            "img": img,
            "ref_imgs": ref_imgs_new,
        }
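A minimal usage sketch for the dataset above, assuming a JSON file following the `dreambench_style.json` schema (the paths are illustrative):

from torch.utils.data import DataLoader

from dst.dataset.dst import FluxPairedDatasetV2

dataset = FluxPairedDatasetV2("DST/datasets/dreambench_style.json", resolution=1024)
loader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)

batch = next(iter(loader))
print(batch["img"].shape)      # (B, 3, 1024, 1024) target images in [-1, 1]
print(len(batch["ref_imgs"]))  # number of reference-image slots per sample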
DST/dst/flux/math.py ADDED
@@ -0,0 +1,31 @@
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    # rotate queries/keys by their positional phases, then run standard SDPA
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    # build the per-position, per-frequency 2x2 rotation matrices [[cos, -sin], [sin, cos]]
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # view the last dim as (pairs, 2) and apply the rotation matrices from rope()
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
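A small self-check sketch of the RoPE helpers above on toy tensors (shapes only; values are random):

import torch

from dst.flux.math import attention, rope

B, H, L, D = 1, 2, 4, 8                  # D must be even for rope()
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
pos = torch.arange(L)[None].float()      # (B, L) token positions
pe = rope(pos, D, 10_000).unsqueeze(1)   # (B, 1, L, D/2, 2, 2) rotation matrices
out = attention(q, k, v, pe=pe)
print(out.shape)                         # torch.Size([1, 4, 16]) = (B, L, H*D)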
DST/dst/flux/model.py ADDED
@@ -0,0 +1,208 @@
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    _supports_gradient_checkpointing = True

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        # collect processors recursively
        processors = {}  # type: dict[str, nn.Module]

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class, or a dictionary of processor classes that will be set as
                the processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        ref_img: Tensor | None = None,
        ref_img_ids: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance-distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)

        # concatenate reference-image tokens (and their ids) after the target image tokens
        img_end = img.shape[1]
        if ref_img is not None:
            if isinstance(ref_img, (tuple, list)):
                img_in = [img] + [self.img_in(ref) for ref in ref_img]
                img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
                img = torch.cat(img_in, dim=1)
                ids = torch.cat(img_ids, dim=1)
            else:
                img = torch.cat((img, self.img_in(ref_img)), dim=1)
                ids = torch.cat((ids, ref_img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for index_block, block in enumerate(self.double_blocks):
            if self.training and self.gradient_checkpointing:
                img, txt = torch.utils.checkpoint.checkpoint(
                    block,
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe,
                    use_reentrant=False,
                )
            else:
                img, txt = block(
                    img=img,
                    txt=txt,
                    vec=vec,
                    pe=pe,
                )

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            if self.training and self.gradient_checkpointing:
                img = torch.utils.checkpoint.checkpoint(
                    block,
                    img, vec=vec, pe=pe,
                    use_reentrant=False,
                )
            else:
                img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]
        # keep only the target image tokens, dropping the appended reference tokens
        img = img[:, :img_end, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
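A toy-sized instantiation sketch (deliberately tiny hyperparameters, not the real FLUX ones; it only demonstrates that shapes flow through `forward`):

import torch

from dst.flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=16, vec_in_dim=8, context_in_dim=32, hidden_size=64,
    mlp_ratio=4.0, num_heads=4, depth=1, depth_single_blocks=1,
    axes_dim=[4, 6, 6],  # must sum to hidden_size // num_heads
    theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params)

B, L_img, L_txt = 1, 16, 8
out = model(
    img=torch.randn(B, L_img, 16), img_ids=torch.zeros(B, L_img, 3),
    txt=torch.randn(B, L_txt, 32), txt_ids=torch.zeros(B, L_txt, 3),
    timesteps=torch.ones(B), y=torch.randn(B, 8),
)
print(out.shape)  # (1, 16, 16): one prediction per target image token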
DST/dst/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int | None = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
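A toy round-trip sketch (far smaller than a real FLUX VAE configuration; scale/shift are left neutral):

import torch

from dst.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=64, in_channels=3, ch=32, out_ch=3, ch_mult=[1, 2],
    num_res_blocks=1, z_channels=4, scale_factor=1.0, shift_factor=0.0,
)
ae = AutoEncoder(params)

x = torch.randn(1, 3, 64, 64)        # an image in [-1, 1]
z = ae.encode(x)                     # (1, 4, 32, 32): one 2x downsample per extra ch_mult entry
print(z.shape, ae.decode(z).shape)   # latent, then the reconstructed (1, 3, 64, 64) image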
DST/dst/flux/modules/conditioner.py ADDED
@@ -0,0 +1,39 @@
from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = "clip" in version.lower()
        self.max_length = max_length
        # CLIP supplies a single pooled vector; T5 supplies per-token hidden states
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
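A usage sketch; the checkpoint names below are the encoders FLUX-style models commonly pair with CLIP/T5 and are assumptions on my part, not pinned by this diff:

import torch

from dst.flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
vec = clip(["a watercolor painting of a tower"])   # pooled output: (1, 768)

t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16)
txt = t5(["a watercolor painting of a tower"])     # per-token states: (1, 512, 4096)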
DST/dst/flux/modules/layers.py ADDED
@@ -0,0 +1,421 @@
import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn

from ..math import attention, rope


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class FLuxSelfAttnProcessor:
    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x)
        # the head count lives on the SelfAttention module, not on the processor
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x)
        return x


class LoraFluxAttnProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
        return x


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self):
        # attention is computed by the owning block's processor; this module only holds the weights
        pass


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DoubleStreamBlockProcessor:
    def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads

        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        processor = DoubleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor) -> None:
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor | None = None,
        ip_scale: float = 1.0,
    ) -> tuple[Tensor, Tensor]:
        if image_proj is None:
            return self.processor(self, img, txt, vec, pe)
        else:
            return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)


class SingleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        # 15360 = dim + mlp_hidden_dim for dim=3072: the projection consumes the [attn, mlp] concat
        self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
        qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
        output = x + mod.gate * output
        return output


class SingleStreamBlockProcessor:
    def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)

        # compute attention
        attn_1 = attention(q, k, v, pe=pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = x + mod.gate * output
        return output


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(self.head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        processor = SingleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor) -> None:
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(
        self,
        x: Tensor,
        vec: Tensor,
        pe: Tensor,
        image_proj: Tensor | None = None,
        ip_scale: float = 1.0,
    ) -> Tensor:
        if image_proj is None:
            return self.processor(self, x, vec, pe)
        else:
            return self.processor(self, x, vec, pe, image_proj, ip_scale)


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
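A sketch of the processor-swapping mechanism these classes enable (the same thing `DSTPipeline.update_model_with_lora` does across the whole model), on a toy-sized block:

import torch

from dst.flux.math import rope
from dst.flux.modules.layers import DoubleStreamBlock, DoubleStreamBlockLoraProcessor

block = DoubleStreamBlock(hidden_size=64, num_heads=4, mlp_ratio=4.0, qkv_bias=True)
# swap the vanilla processor for the LoRA variant; the block's weights are untouched
block.set_processor(DoubleStreamBlockLoraProcessor(dim=64, rank=4, lora_weight=0.7))

B, L_img, L_txt = 1, 16, 8
pos = torch.arange(L_txt + L_img)[None].float()
pe = rope(pos, 16, 10_000).unsqueeze(1)   # head_dim = 64 // 4 = 16

img, txt = block(img=torch.randn(B, L_img, 64), txt=torch.randn(B, L_txt, 64),
                 vec=torch.randn(B, 64), pe=pe)
print(img.shape, txt.shape)               # (1, 16, 64) (1, 8, 64)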
DST/dst/flux/pipeline.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import os
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from einops import rearrange
8
+ from PIL import ExifTags, Image
9
+ import torchvision.transforms.functional as TVF
10
+
11
+ from dst.flux.modules.layers import (
12
+ DoubleStreamBlockLoraProcessor,
13
+ DoubleStreamBlockProcessor,
14
+ SingleStreamBlockLoraProcessor,
15
+ SingleStreamBlockProcessor,
16
+ )
17
+ from dst.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
18
+ from dst.flux.util import (
19
+ get_lora_rank,
20
+ load_ae,
21
+ load_checkpoint,
22
+ load_clip,
23
+ load_flow_model,
24
+ load_flow_model_only_lora,
25
+ load_flow_model_quintized,
26
+ load_t5,
27
+ )
28
+
29
+
30
+ def find_nearest_scale(image_h, image_w, predefined_scales):
31
+ """
32
+ 根据图片的高度和宽度,找到最近的预定义尺度。
33
+
34
+ :param image_h: 图片的高度
35
+ :param image_w: 图片的宽度
36
+ :param predefined_scales: 预定义尺度列表 [(h1, w1), (h2, w2), ...]
37
+ :return: 最近的预定义尺度 (h, w)
38
+ """
39
+ # 计算输入图片的长宽比
40
+ image_ratio = image_h / image_w
41
+
42
+ # 初始化变量以存储最小差异和最近的尺度
43
+ min_diff = float('inf')
44
+ nearest_scale = None
45
+
46
+ # 遍历所有预定义尺度,找到与输入图片长宽比最接近的尺度
47
+ for scale_h, scale_w in predefined_scales:
48
+ predefined_ratio = scale_h / scale_w
49
+ diff = abs(predefined_ratio - image_ratio)
50
+
51
+ if diff < min_diff:
52
+ min_diff = diff
53
+ nearest_scale = (scale_h, scale_w)
54
+
55
+ return nearest_scale
56
+
57
+ def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
58
+ # 获取原始图像的宽度和高度
59
+ image_w, image_h = raw_image.size
60
+
61
+ # 计算长边和短边
62
+ if image_w >= image_h:
63
+ new_w = long_size
64
+        new_h = int((long_size / image_w) * image_h)
+    else:
+        new_h = long_size
+        new_w = int((long_size / image_h) * image_w)
+
+    # resize proportionally to the new width/height
+    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
+    target_w = new_w // 16 * 16
+    target_h = new_h // 16 * 16
+
+    # compute the crop origin for a center crop
+    left = (new_w - target_w) // 2
+    top = (new_h - target_h) // 2
+    right = left + target_w
+    bottom = top + target_h
+
+    # center-crop to the 16-multiple target size
+    raw_image = raw_image.crop((left, top, right, bottom))
+
+    # convert to RGB
+    raw_image = raw_image.convert("RGB")
+    return raw_image
+
+
+class DSTPipeline:
+    def __init__(
+        self,
+        model_type: str,
+        device: torch.device,
+        offload: bool = False,
+        only_lora: bool = False,
+        lora_rank: int = 16
+    ):
+        self.device = device
+        self.offload = offload
+        self.model_type = model_type
+
+        self.clip = load_clip(self.device)
+        self.t5 = load_t5(self.device, max_length=512)
+        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
+        self.use_fp8 = "fp8" in model_type
+
+        if only_lora:
+            self.model = load_flow_model_only_lora(
+                model_type,
+                device="cpu" if offload else self.device,
+                lora_rank=lora_rank,
+                use_fp8=self.use_fp8
+            )
+        else:
+            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
+
+    def load_ckpt(self, ckpt_path):
+        if ckpt_path is not None:
+            from safetensors.torch import load_file as load_sft
+            print("Loading checkpoint to replace old keys")
+            # load_sft doesn't support torch.device
+            if ckpt_path.endswith('safetensors'):
+                sd = load_sft(ckpt_path, device='cpu')
+                missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
+            else:
+                dit_state = torch.load(ckpt_path, map_location='cpu')
+                sd = {}
+                for k in dit_state.keys():
+                    sd[k.replace('module.', '')] = dit_state[k]
+                missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
+            self.model.to(str(self.device))
+            print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
+
+    def set_lora(self, local_path: str = None, repo_id: str = None,
+                 name: str = None, lora_weight: float = 0.7):
+        checkpoint = load_checkpoint(local_path, repo_id, name)
+        self.update_model_with_lora(checkpoint, lora_weight)
+
+    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
+        checkpoint = load_checkpoint(
+            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
+        )
+        self.update_model_with_lora(checkpoint, lora_weight)
+
+    def update_model_with_lora(self, checkpoint, lora_weight):
+        rank = get_lora_rank(checkpoint)
+        lora_attn_procs = {}
+
+        for name, _ in self.model.attn_processors.items():
+            lora_state_dict = {}
+            for k in checkpoint.keys():
+                if name in k:
+                    lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
+
+            if len(lora_state_dict):
+                if name.startswith("single_blocks"):
+                    lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
+                else:
+                    lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
+                lora_attn_procs[name].load_state_dict(lora_state_dict)
+                lora_attn_procs[name].to(self.device)
+            else:
+                if name.startswith("single_blocks"):
+                    lora_attn_procs[name] = SingleStreamBlockProcessor()
+                else:
+                    lora_attn_procs[name] = DoubleStreamBlockProcessor()
+
+        self.model.set_attn_processor(lora_attn_procs)
+
+    def __call__(
+        self,
+        prompt: str,
+        width: int = 512,
+        height: int = 512,
+        guidance: float = 4,
+        num_steps: int = 50,
+        seed: int = 123456789,
+        **kwargs
+    ):
+        width = 16 * (width // 16)
+        height = 16 * (height // 16)
+
+        device_type = self.device if isinstance(self.device, str) else self.device.type
+        if device_type == "mps":
+            device_type = "cpu"  # autocast does not support macOS MPS, fall back to CPU
+        with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
+            return self.forward(
+                prompt,
+                width,
+                height,
+                guidance,
+                num_steps,
+                seed,
+                **kwargs
+            )
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        prompt: str,
+        width: int,
+        height: int,
+        guidance: float,
+        num_steps: int,
+        seed: int,
+        ref_imgs: list[Image.Image] | None = None,
+        pe: Literal['d', 'h', 'w', 'o'] = 'd',
+    ):
+        x = get_noise(
+            1, height, width, device=self.device,
+            dtype=torch.bfloat16, seed=seed
+        )
+        timesteps = get_schedule(
+            num_steps,
+            (width // 8) * (height // 8) // (16 * 16),
+            shift=True,
+        )
+        if self.offload:
+            self.ae.encoder = self.ae.encoder.to(self.device)
+        x_1_refs = [
+            self.ae.encode(
+                (TVF.to_tensor(ref_img) * 2.0 - 1.0)
+                .unsqueeze(0).to(self.device, torch.float32)
+            ).to(torch.bfloat16)
+            for ref_img in ref_imgs
+        ]
+
+        if self.offload:
+            self.offload_model_to_cpu(self.ae.encoder)
+            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
+        inp_cond = prepare_multi_ip(
+            t5=self.t5, clip=self.clip,
+            img=x,
+            prompt=prompt, ref_imgs=x_1_refs, pe=pe
+        )
+
+        if self.offload:
+            self.offload_model_to_cpu(self.t5, self.clip)
+            self.model = self.model.to(self.device)
+
+        x = denoise(
+            self.model,
+            **inp_cond,
+            timesteps=timesteps,
+            guidance=guidance,
+        )
+
+        if self.offload:
+            self.offload_model_to_cpu(self.model)
+            self.ae.decoder.to(x.device)
+        x = unpack(x.float(), height, width)
+        x = self.ae.decode(x)
+        self.offload_model_to_cpu(self.ae.decoder)
+
+        x1 = x.clamp(-1, 1)
+        x1 = rearrange(x1[-1], "c h w -> h w c")
+        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
+        return output_img
+
+    def offload_model_to_cpu(self, *models):
+        if not self.offload:
+            return
+        for model in models:
+            model.cpu()
+            torch.cuda.empty_cache()
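The resize helper that opens `pipeline.py` enters this diff view mid-function. As a reading aid, here is a minimal self-contained sketch of the same logic (the function name `resize_long_side_to_16_multiple` is ours, not from the file):

```python
from PIL import Image

# Sketch of the preprocessing above: scale the long side to `long_size`,
# then center-crop both dimensions down to multiples of 16.
def resize_long_side_to_16_multiple(raw_image: Image.Image, long_size: int) -> Image.Image:
    image_w, image_h = raw_image.size
    if image_w >= image_h:          # width is the long side
        new_w = long_size
        new_h = int((long_size / image_w) * image_h)
    else:                           # height is the long side
        new_h = long_size
        new_w = int((long_size / image_h) * image_w)
    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
    target_w, target_h = new_w // 16 * 16, new_h // 16 * 16
    left, top = (new_w - target_w) // 2, (new_h - target_h) // 2
    return raw_image.crop((left, top, left + target_w, top + target_h)).convert("RGB")
```

The 16-pixel granularity matches the 8x VAE downsampling followed by the 2x2 latent packing in `sampling.py`.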
DST/dst/flux/sampling.py ADDED
@@ -0,0 +1,243 @@
+
+import math
+from typing import Literal
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor
+from tqdm import tqdm
+
+from .model import Flux
+from .modules.conditioner import HFEmbedder
+
+
+def get_noise(
+    num_samples: int,
+    height: int,
+    width: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    seed: int,
+):
+    return torch.randn(
+        num_samples,
+        16,
+        # allow for packing
+        2 * math.ceil(height / 16),
+        2 * math.ceil(width / 16),
+        device=device,
+        dtype=dtype,
+        generator=torch.Generator(device=device).manual_seed(seed),
+    )
+
+
+def prepare(
+    t5: HFEmbedder,
+    clip: HFEmbedder,
+    img: Tensor,
+    prompt: str | list[str],
+    ref_img: None | Tensor = None,
+    pe: Literal['d', 'h', 'w', 'o'] = 'd'
+) -> dict[str, Tensor]:
+    assert pe in ['d', 'h', 'w', 'o']
+    bs, c, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    if ref_img is not None:
+        _, _, ref_h, ref_w = ref_img.shape
+        ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if ref_img.shape[0] == 1 and bs > 1:
+            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+        ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
+        # offset the reference image ids past the target image's maximum height/width
+        h_offset = h // 2 if pe in {'d', 'h'} else 0
+        w_offset = w // 2 if pe in {'d', 'w'} else 0
+        ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
+        ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
+        ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+    txt = t5(prompt)
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    if ref_img is not None:
+        return {
+            "img": img,
+            "img_ids": img_ids.to(img.device),
+            "ref_img": ref_img,
+            "ref_img_ids": ref_img_ids.to(img.device),
+            "txt": txt.to(img.device),
+            "txt_ids": txt_ids.to(img.device),
+            "vec": vec.to(img.device),
+        }
+    else:
+        return {
+            "img": img,
+            "img_ids": img_ids.to(img.device),
+            "txt": txt.to(img.device),
+            "txt_ids": txt_ids.to(img.device),
+            "vec": vec.to(img.device),
+        }
+
+
+def prepare_multi_ip(
+    t5: HFEmbedder,
+    clip: HFEmbedder,
+    img: Tensor,
+    prompt: str | list[str],
+    ref_imgs: list[Tensor] | None = None,
+    pe: Literal['d', 'h', 'w', 'o'] = 'd'
+) -> dict[str, Tensor]:
+    assert pe in ['d', 'h', 'w', 'o']
+    bs, c, h, w = img.shape
+    if bs == 1 and not isinstance(prompt, str):
+        bs = len(prompt)
+
+    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+    if img.shape[0] == 1 and bs > 1:
+        img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+    img_ids = torch.zeros(h // 2, w // 2, 3)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    ref_img_ids = []
+    ref_imgs_list = []
+    pe_shift_w, pe_shift_h = w // 2, h // 2
+    for ref_img in ref_imgs:
+        _, _, ref_h1, ref_w1 = ref_img.shape
+        ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if ref_img.shape[0] == 1 and bs > 1:
+            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+        ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
+        # offset each reference image's ids past the running height/width maxima
+        h_offset = pe_shift_h if pe in {'d', 'h'} else 0
+        w_offset = pe_shift_w if pe in {'d', 'w'} else 0
+        ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
+        ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
+        ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
+        ref_img_ids.append(ref_img_ids1)
+        ref_imgs_list.append(ref_img)
+
+        # update the PE shift for the next reference image
+        pe_shift_h += ref_h1 // 2
+        pe_shift_w += ref_w1 // 2
+
+    if isinstance(prompt, str):
+        prompt = [prompt]
+    txt = t5(prompt)
+    if txt.shape[0] == 1 and bs > 1:
+        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+    txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+    vec = clip(prompt)
+    if vec.shape[0] == 1 and bs > 1:
+        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+    return {
+        "img": img,
+        "img_ids": img_ids.to(img.device),
+        "ref_img": tuple(ref_imgs_list),
+        "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
+        "txt": txt.to(img.device),
+        "txt_ids": txt_ids.to(img.device),
+        "vec": vec.to(img.device),
+    }
+
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(
+    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+):
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+
+def get_schedule(
+    num_steps: int,
+    image_seq_len: int,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    shift: bool = True,
+) -> list[float]:
+    # extra step for zero
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+
+    # shifting the schedule to favor high timesteps for higher signal images
+    if shift:
+        # estimate mu with a linear interpolation between two points
+        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+        timesteps = time_shift(mu, 1.0, timesteps)
+
+    return timesteps.tolist()
+
+
+def denoise(
+    model: Flux,
+    # model input
+    img: Tensor,
+    img_ids: Tensor,
+    txt: Tensor,
+    txt_ids: Tensor,
+    vec: Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    guidance: float = 4.0,
+    ref_img: Tensor = None,
+    ref_img_ids: Tensor = None,
+):
+    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
+        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+        pred = model(
+            img=img,
+            img_ids=img_ids,
+            ref_img=ref_img,
+            ref_img_ids=ref_img_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t_vec,
+            guidance=guidance_vec
+        )
+        # Euler step of the rectified-flow ODE
+        img = img + (t_prev - t_curr) * pred
+    return img
+
+
+def unpack(x: Tensor, height: int, width: int) -> Tensor:
+    return rearrange(
+        x,
+        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+        h=math.ceil(height / 16),
+        w=math.ceil(width / 16),
+        ph=2,
+        pw=2,
+    )
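The only place resolution enters the sampler is `get_schedule`. A small standalone check (same formulas as above; the helper name `mu_for` is ours) makes the shift concrete:

```python
import math
import torch

def mu_for(image_seq_len: int, x1=256.0, y1=0.5, x2=4096.0, y2=1.15) -> float:
    # linear interpolation between (x1, y1) and (x2, y2), as in get_lin_function
    m = (y2 - y1) / (x2 - x1)
    return m * image_seq_len + (y1 - m * x1)

mu = mu_for(4096)                                  # a 4096-token latent gives mu = 1.15
t = torch.linspace(1, 0, 11)                       # 10 steps plus the final zero
shifted = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** 1.0)
print(round(shifted[5].item(), 2))                 # 0.5 -> ~0.76: more time at high noise
```

The endpoints are fixed (t = 1 maps to 1, t = 0 to 0); only the interior of the schedule is warped toward higher timesteps.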
DST/dst/flux/util.py ADDED
@@ -0,0 +1,404 @@
+
+
+import os
+import re
+import json
+from dataclasses import dataclass
+
+import torch
+import numpy as np
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from safetensors.torch import load_file as load_sft
+
+from .model import Flux, FluxParams
+from .modules.autoencoder import AutoEncoder, AutoEncoderParams
+from .modules.conditioner import HFEmbedder
+from dst.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
+
+
+def load_model(ckpt, device='cpu'):
+    if ckpt.endswith('safetensors'):
+        pl_sd = {}
+        with safe_open(ckpt, framework="pt", device=device) as f:
+            for k in f.keys():
+                pl_sd[k] = f.get_tensor(k)
+    else:
+        pl_sd = torch.load(ckpt, map_location=device)
+    return pl_sd
+
+
+def load_safetensors(path):
+    tensors = {}
+    with safe_open(path, framework="pt", device="cpu") as f:
+        for key in f.keys():
+            tensors[key] = f.get_tensor(key)
+    return tensors
+
+
+def get_lora_rank(checkpoint):
+    for k in checkpoint.keys():
+        if k.endswith(".down.weight"):
+            return checkpoint[k].shape[0]
+
+
+def load_checkpoint(local_path, repo_id, name):
+    if local_path is not None:
+        if '.safetensors' in local_path:
+            print(f"Loading .safetensors checkpoint from {local_path}")
+            checkpoint = load_safetensors(local_path)
+        else:
+            print(f"Loading checkpoint from {local_path}")
+            checkpoint = torch.load(local_path, map_location='cpu')
+    elif repo_id is not None and name is not None:
+        print(f"Loading checkpoint {name} from repo id {repo_id}")
+        checkpoint = load_from_repo_id(repo_id, name)
+    else:
+        raise ValueError(
+            "LOADING ERROR: you must specify local_path, or repo_id with name, to download from HF"
+        )
+    return checkpoint
+
+
+def c_crop(image):
+    width, height = image.size
+    new_size = min(width, height)
+    left = (width - new_size) / 2
+    top = (height - new_size) / 2
+    right = (width + new_size) / 2
+    bottom = (height + new_size) / 2
+    return image.crop((left, top, right, bottom))
+
+
+def pad64(x):
+    return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        # composite RGBA onto a white background
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+
+@dataclass
+class ModelSpec:
+    params: FluxParams
+    ae_params: AutoEncoderParams
+    ckpt_path: str | None
+    ae_path: str | None
+    repo_id: str | None
+    repo_flow: str | None
+    repo_ae: str | None
+    repo_id_ae: str | None
+
+
+configs = {
+    "flux-dev": ModelSpec(
+        # NOTE: points at a local snapshot of black-forest-labs/FLUX.1-dev
+        repo_id="/data1/huggingface_ckpts/FLUX.1-dev",
+        repo_id_ae="/data1/huggingface_ckpts/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-dev-fp8": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_id_ae="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV_FP8"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-schnell": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-schnell",
+        repo_id_ae="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-schnell.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_SCHNELL"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=False,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+}
+
+
+def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+    if len(missing) > 0 and len(unexpected) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+        print("\n" + "-" * 79 + "\n")
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+    elif len(missing) > 0:
+        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+    elif len(unexpected) > 0:
+        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+
+def load_from_repo_id(repo_id, checkpoint_name):
+    ckpt_path = hf_hub_download(repo_id, checkpoint_name)
+    sd = load_sft(ckpt_path, device='cpu')
+    return sd
+
+
+def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+    # Loading Flux
+    print("Init model")
+    ckpt_path = configs[name].ckpt_path
+    if (
+        ckpt_path is None
+        and configs[name].repo_id is not None
+        and configs[name].repo_flow is not None
+        and hf_download
+    ):
+        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+
+    with torch.device("meta" if ckpt_path is not None else device):
+        model = Flux(configs[name].params).to(torch.bfloat16)
+
+    if ckpt_path is not None:
+        print("Loading checkpoint")
+        # load_sft doesn't support torch.device
+        sd = load_model(ckpt_path, device=str(device))
+        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+        print_load_warning(missing, unexpected)
+    return model
+
+
+def load_flow_model_only_lora(
+    name: str,
+    device: str | torch.device = "cuda",
+    hf_download: bool = False,
+    lora_rank: int = 16,
+    use_fp8: bool = False
+):
+    # Loading Flux
+    print("Init model")
+    ckpt_path = configs[name].ckpt_path
+    if (
+        ckpt_path is None
+        and configs[name].repo_id is not None
+        and configs[name].repo_flow is not None
+        and hf_download
+    ):
+        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
+
+    if hf_download:
+        try:
+            lora_ckpt_path = hf_hub_download("", "dit_lora.safetensors")
+        except Exception:
+            lora_ckpt_path = os.environ.get("LORA", None)
+    else:
+        lora_ckpt_path = os.environ.get("LORA", None)
+
+    with torch.device("meta" if ckpt_path is not None else device):
+        model = Flux(configs[name].params)
+
+    model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)
+
+    if ckpt_path is not None:
+        print("Loading lora")
+        lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors") \
+            else torch.load(lora_ckpt_path, map_location='cpu')
+
+        print("Loading main checkpoint")
+        # load_sft doesn't support torch.device
+        if ckpt_path.endswith('safetensors'):
+            if use_fp8:
+                print(
+                    "####\n"
+                    "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken.\n"
+                    "We convert the fp8 checkpoint on the fly from the bf16 checkpoint.\n"
+                    "If your storage is constrained, "
+                    "you can save the fp8 checkpoint and replace the bf16 checkpoint yourself.\n"
+                )
+                sd = load_sft(ckpt_path, device="cpu")
+                sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
+            else:
+                sd = load_sft(ckpt_path, device=str(device))
+
+            # for k in lora_sd:
+            #     if isinstance(lora_sd[k], torch.Tensor):
+            #         lora_sd[k] = lora_sd[k].to(device)
+
+            sd.update(lora_sd)
+            missing, unexpected = model.load_state_dict(sd, strict=True, assign=True)
+        else:
+            dit_state = torch.load(ckpt_path, map_location='cpu')
+            sd = {}
+            for k in dit_state.keys():
+                sd[k.replace('module.', '')] = dit_state[k]
+            sd.update(lora_sd)
+            missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+
+        model.to(str(device))
+        print_load_warning(missing, unexpected)
+    return model
+
+
+def set_lora(
+    model: Flux,
+    lora_rank: int,
+    double_blocks_indices: list[int] | None = None,
+    single_blocks_indices: list[int] | None = None,
+    device: str | torch.device = "cpu",
+) -> Flux:
+    double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
+    single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
+        else single_blocks_indices
+
+    lora_attn_procs = {}
+    with torch.device(device):
+        for name, attn_processor in model.attn_processors.items():
+            match = re.search(r'\.(\d+)\.', name)
+            # guard against processor names without a layer index
+            layer_index = int(match.group(1)) if match else -1
+
+            if name.startswith("double_blocks") and layer_index in double_blocks_indices:
+                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
+            elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
+                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
+            else:
+                lora_attn_procs[name] = attn_processor
+    model.set_attn_processor(lora_attn_procs)
+    return model
+
+
+def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+    # Loading Flux (function name kept as-is; "quintized" is a typo for "quantized")
+    from optimum.quanto import requantize
+    print("Init model")
+    ckpt_path = configs[name].ckpt_path
+    if (
+        ckpt_path is None
+        and configs[name].repo_id is not None
+        and configs[name].repo_flow is not None
+        and hf_download
+    ):
+        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+    # json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
+
+    model = Flux(configs[name].params).to(torch.bfloat16)
+
+    print("Loading checkpoint")
+    # load_sft doesn't support torch.device
+    sd = load_sft(ckpt_path, device='cpu')
+    sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
+    model.load_state_dict(sd, assign=True)
+    return model
+    # NOTE: the requantize path below was unreachable (and json_path is never set, since
+    # the quantization-map download above is commented out); kept for reference.
+    # with open(json_path, "r") as f:
+    #     quantization_map = json.load(f)
+    # print("Start a quantization process...")
+    # requantize(model, sd, quantization_map, device=device)
+    # print("Model is quantized!")
+    # return model
+
+
+def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
+    version = os.environ.get("T5", "/root/filesystem/Destyle_OmniStyle/weights/xlabs-ai/xflux_text_encoders")
+    return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+    version = os.environ.get("CLIP", "/root/filesystem/Destyle_OmniStyle/weights/AI-ModelScope/clip-vit-large-patch14")
+    return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
+    ckpt_path = configs[name].ae_path
+    if (
+        ckpt_path is None
+        and configs[name].repo_id is not None
+        and configs[name].repo_ae is not None
+        and hf_download
+    ):
+        ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
+
+    # Loading the autoencoder
+    print("Init AE")
+    with torch.device("meta" if ckpt_path is not None else device):
+        ae = AutoEncoder(configs[name].ae_params)
+
+    if ckpt_path is not None:
+        sd = load_sft(ckpt_path, device=str(device))
+        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+        print_load_warning(missing, unexpected)
+    return ae
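Note that `configs` captures `FLUX_DEV`/`AE` via `os.getenv()` when the module is imported, while `load_t5`, `load_clip`, and the LoRA path in `load_flow_model_only_lora` read `T5`, `CLIP`, and `LORA` at call time. A minimal sketch of driving the loaders from Python instead of `test.sh` (same variable names; the paths are placeholders mirroring `test.sh`):

```python
import os

# Set the checkpoint locations *before* importing dst.flux.util, because
# `configs` reads FLUX_DEV/AE at module import time.
os.environ["FLUX_DEV"] = "./weights/FLUX.1-dev/flux1-dev.safetensors"
os.environ["AE"] = "./weights/FLUX.1-dev/ae.safetensors"
os.environ["T5"] = "./weights/xflux_text_encoders"
os.environ["CLIP"] = "./weights/clip-vit-large-patch14"
os.environ["LORA"] = "./save/1024_modernart/dit_lora.safetensors"

from dst.flux.util import load_ae, load_clip, load_flow_model_only_lora, load_t5

model = load_flow_model_only_lora("flux-dev", device="cuda", lora_rank=512)
ae = load_ae("flux-dev", device="cuda")
t5, clip = load_t5("cuda"), load_clip("cuda")
```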
DST/dst/utils/convert_yaml_to_args_file.py ADDED
@@ -0,0 +1,21 @@
+
+import argparse
+
+import yaml
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--yaml", type=str, required=True)
+parser.add_argument("--arg", type=str, required=True)
+args = parser.parse_args()
+
+
+with open(args.yaml, "r") as f:
+    data = yaml.safe_load(f)
+
+with open(args.arg, "w") as f:
+    for k, v in data.items():
+        if isinstance(v, list):
+            # join list values into a single space-separated string
+            v = " ".join(map(str, v))
+        if v is None:
+            continue
+        print(f"--{k} {v}", end=" ", file=f)
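For example, given a small config (the field names here are made up for illustration), the script emits one line of space-separated flags, joining lists and skipping nulls. A sketch of the same transformation inline:

```python
import yaml

# Mirror of the conversion loop above, run on an inline document.
data = yaml.safe_load("""
batch_size: 4
lr: 0.0001
double_blocks_indices: [0, 1, 2]
resume_ckpt: null
""")

out = []
for k, v in data.items():
    if isinstance(v, list):
        v = " ".join(map(str, v))
    if v is None:
        continue                    # null values are dropped entirely
    out.append(f"--{k} {v}")
print(" ".join(out))
# --batch_size 4 --lr 0.0001 --double_blocks_indices 0 1 2
```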
DST/inference.py ADDED
@@ -0,0 +1,104 @@
+
+
+import os
+import dataclasses
+from typing import Literal
+
+from accelerate import Accelerator
+from transformers import HfArgumentParser
+from PIL import Image
+
+from dst.flux.pipeline import DSTPipeline
+
+
+@dataclasses.dataclass
+class InferenceArgs:
+    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
+    width: int = 1024
+    height: int = 1024
+    ref_size: int = 1024
+    num_steps: int = 25
+    guidance: float = 4
+    seed: int = 0
+    only_lora: bool = True
+    concat_refs: bool = True
+    lora_rank: int = 512
+    pe: Literal['d', 'h', 'w', 'o'] = 'd'
+
+
+def crop_if_not_square(img):
+    w, h = img.size
+    if w != h:
+        min_dim = min(w, h)
+        left = (w - min_dim) // 2
+        top = (h - min_dim) // 2
+        right = left + min_dim
+        bottom = top + min_dim
+        img = img.crop((left, top, right, bottom))
+    return img
+
+
+def main(args: InferenceArgs):
+    accelerator = Accelerator()
+    device = accelerator.device
+
+    # test modern art images
+    test_cnt_folder = "./test/cnt/"
+    test_sty_folder = "./test/sty/"
+    # test real paintings
+    # test_cnt_folder = "./test/cnt_nga"
+    # test_sty_folder = "./test/sty_nga"
+    save_folder = "./output/"
+    os.makedirs(save_folder, exist_ok=True)
+
+    pipeline = DSTPipeline(
+        args.model_type,
+        device,
+        accelerator.state.deepspeed_plugin is not None,
+        only_lora=args.only_lora,
+        lora_rank=args.lora_rank
+    )
+
+    for sty_img in os.listdir(test_sty_folder):
+        for cnt_img in os.listdir(test_cnt_folder):
+            save_name = os.path.join(save_folder, f"{os.path.splitext(cnt_img)[0]}@{os.path.splitext(sty_img)[0]}.jpg")
+            # if os.path.exists(save_name):
+            #     continue
+
+            cnt_path = os.path.join(test_cnt_folder, cnt_img)
+            sty_path = os.path.join(test_sty_folder, sty_img)
+
+            cnt_img_pil = Image.open(cnt_path).convert('RGB')
+            sty_img_pil = Image.open(sty_path).convert('RGB')
+            cnt_center_crop = crop_if_not_square(cnt_img_pil)
+            sty_center_crop = crop_if_not_square(sty_img_pil)
+
+            cnt_img_pil = cnt_center_crop.resize((args.width, args.height))
+            sty_img_pil = sty_center_crop.resize((args.width, args.height))
+
+            ref_imgs = [sty_img_pil, cnt_img_pil]
+
+            image_gen = pipeline(
+                prompt="",
+                width=args.width,
+                height=args.height,
+                guidance=args.guidance,
+                num_steps=args.num_steps,
+                seed=args.seed,
+                ref_imgs=ref_imgs,
+                pe=args.pe,
+            )
+
+            if args.concat_refs:
+                # save a content | style | result triptych
+                new_blank_img = Image.new('RGB', (args.width * 3, args.height))
+                new_blank_img.paste(cnt_img_pil, (0, 0))
+                new_blank_img.paste(sty_img_pil, (args.width, 0))
+                new_blank_img.paste(image_gen, (args.width * 2, 0))
+                new_blank_img.save(save_name)
+            else:
+                # without this branch the generation was silently dropped
+                image_gen.save(save_name)
+
+
+if __name__ == "__main__":
+    parser = HfArgumentParser([InferenceArgs])
+    args = parser.parse_args_into_dataclasses()[0]
+    main(args)
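Because `HfArgumentParser` maps every `InferenceArgs` field to a CLI flag, the defaults can be overridden without editing the file. A sketch with a synthetic argv (`test.sh` simply runs the defaults):

```python
from transformers import HfArgumentParser

from inference import InferenceArgs  # safe to import: main() is __main__-guarded

parser = HfArgumentParser([InferenceArgs])
(args,) = parser.parse_args_into_dataclasses(
    args=["--num_steps", "50", "--seed", "42", "--pe", "h"]
)
print(args.num_steps, args.seed, args.pe)  # -> 50 42 h
```

Outputs are named `{content}@{style}.jpg`, which is exactly the pattern of the files under `DST/output/` below.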
DST/output/tower@American Comic_Architecture_Church or mosque_7f69557d-751b-4c7c-9495-5156b2513989_42.jpg ADDED

Git LFS Details

  • SHA256: ab80a6fa085cfb6a2e8fee7dff5525330bd475241db309610208cc556b31c261
  • Pointer size: 131 Bytes
  • Size of remote file: 573 kB
DST/output/tower@American Comic_Object_Backpack or bag_938b06c4-92bb-4535-a2cf-6112318d0c0d_42.jpg ADDED

Git LFS Details

  • SHA256: 74c9a0f62363ccaa547ea8ae7e56d58becc19569f82c7bd37c4d35a80f54c176
  • Pointer size: 131 Bytes
  • Size of remote file: 632 kB
DST/output/tower@Anime_04c5405f-fcaa-4065-899e-49149e2835e7.jpg ADDED

Git LFS Details

  • SHA256: a12bc99e1f759a3e066d831376bd587702e09141e4656947613e5aed90de1fa6
  • Pointer size: 131 Bytes
  • Size of remote file: 564 kB
DST/output/tower@Anime_Animal_Dog_3d4f8063-1ed9-4601-8c16-6f9f2e77c81b_42.jpg ADDED

Git LFS Details

  • SHA256: 17bdbe533d31aedd5f2b859c7f9c7bf4325b087d05940fa16689f6c7dc59993b
  • Pointer size: 131 Bytes
  • Size of remote file: 485 kB
DST/output/tower@Flat Design_Scene_Beach or coast_31be5a9d-5c07-4c60-be0c-6b0034e8244b_42.jpg ADDED

Git LFS Details

  • SHA256: a2cb837f921a708882f466d80c77ad4530f7149c4577d14e8151279b0f601d06
  • Pointer size: 131 Bytes
  • Size of remote file: 437 kB
DST/output/tower@Flat Design_b5f391f3-c7f4-4c21-8f35-a6f33df8eae5.jpg ADDED

Git LFS Details

  • SHA256: 2040051cf11f8fbc29901159c4034ce18534a940b5936bb0da69fd93632e7c30
  • Pointer size: 131 Bytes
  • Size of remote file: 450 kB
DST/output/tower@Ghibli_19ed1a7f-d7ef-49f6-9f7d-9d5961c385cc.jpg ADDED

Git LFS Details

  • SHA256: 31ee6507a1a1f55190d75093116ff078188abbd54b96b282b7f8eb27a49efaa5
  • Pointer size: 131 Bytes
  • Size of remote file: 571 kB
DST/output/tower@Graffiti_Scene_Forest scene_64c1b63c-d094-4f50-bf71-2ff3d0035dd8_42.jpg ADDED

Git LFS Details

  • SHA256: 8c3076d8a69b9cd042d52672c887f89fb57ec91c4c604b88f063591127e5bffb
  • Pointer size: 131 Bytes
  • Size of remote file: 637 kB
DST/output/tower@Line Art_3f6a80ef-5fa8-492a-a8b6-1b515e3ebd97.jpg ADDED

Git LFS Details

  • SHA256: 16078fd4f2042e5098522bac09dbd110c3e50fa3a6c6c71e884f128b3640cf27
  • Pointer size: 131 Bytes
  • Size of remote file: 993 kB
DST/output/tower@Linocut_15502303-0459-4a02-be7e-b5b93ba69c1e.jpg ADDED

Git LFS Details

  • SHA256: 01b9fcc4b79cb9690d58d9bdb74dae7a8c5a3ff2f4e5b908547d6e594604dea4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
DST/output/tower@Neon_Scene_Beach or coast_c42c6871-6268-41c9-8ec7-a36c999b5acb_42.jpg ADDED

Git LFS Details

  • SHA256: 1a83167c10b0f98b7b11192c35e9147b60dcf85135d02ba53bc8cd95219ac24c
  • Pointer size: 131 Bytes
  • Size of remote file: 556 kB
DST/output/tower@Pixel Art_8b869e57-7345-4f78-8d8b-07a2def7979c.jpg ADDED

Git LFS Details

  • SHA256: f6a8440accc7fe7884f1ba8ab407e28a6acf800db014f3e7ebd50a72d0740628
  • Pointer size: 131 Bytes
  • Size of remote file: 803 kB
DST/output/tower@Watercolor_e15d75e6-796f-4289-ae2e-a0b04ba1a5ea.jpg ADDED

Git LFS Details

  • SHA256: 8075967caabe446e134b5802862492d30dd965ea1c2c5532931a290510ccc721
  • Pointer size: 131 Bytes
  • Size of remote file: 582 kB
DST/readme.md ADDED
@@ -0,0 +1,35 @@
+
+### 🛠️ Installation
+
+1. Install the required packages:
+
+```bash
+pip install -r requirements.txt
+```
+
+---
+
+### 📦 Download Pretrained Weights
+
+2. Before downloading the weights, request access to **[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)** on Hugging Face.
+
+3. Once approved, open `download_weights.sh` and replace `YOUR_TOKEN` with your Hugging Face token.
+
+4. Then run the following to download the weights:
+
+```bash
+cd weights
+bash download_weights.sh
+```
+
+---
+
+### 🚀 Inference
+
+5. Run inference using the provided script (results are written to `./output/`):
+
+```bash
+bash test.sh
+```
+
+---
DST/requirements.txt ADDED
@@ -0,0 +1,16 @@
+accelerate==0.30.1
+deepspeed==0.16.0
+einops==0.8.0
+transformers==4.43.3
+huggingface-hub==0.24.5
+optimum-quanto
+datasets
+omegaconf
+diffusers
+sentencepiece
+opencv-python
+matplotlib
+onnxruntime
+torchvision
+timm
+wandb
DST/run.sh ADDED
@@ -0,0 +1,10 @@
+
+export PATH=$PATH:/root/.local/bin
+export FLUX_DEV="/root/filesystem/Destyle_OmniStyle/weights/AI-ModelScope/FLUX.1-dev/flux1-dev.safetensors"
+export AE="/root/filesystem/Destyle_OmniStyle/weights/AI-ModelScope/FLUX.1-dev/ae.safetensors"
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --multi_gpu --num_processes 8 train.py
+# CUDA_VISIBLE_DEVICES=0,1 accelerate launch --multi_gpu --num_processes 2 train.py
DST/save/1024_modernart/dit_lora.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fff826ffd83ad51e4cc3261438d7be1c55a1dfafc4a78435e26435fe33bd30a8
+size 1912640152
DST/save/1024_nga/dit_lora.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3f0edd1d7f7cdea08d40d36c545357493b4c40934295fbd5802850c580f8289
+size 1912640152
DST/test.sh ADDED
@@ -0,0 +1,7 @@
+export FLUX_DEV="./weights/FLUX.1-dev/flux1-dev.safetensors"
+export AE="./weights/FLUX.1-dev/ae.safetensors"
+export T5="./weights/xflux_text_encoders"
+export CLIP="./weights/clip-vit-large-patch14"
+export LORA="./save/1024_modernart/dit_lora.safetensors"
+
+CUDA_VISIBLE_DEVICES=0 python inference.py
DST/test/cnt/tower.jpg ADDED

Git LFS Details

  • SHA256: e2086946ba39d4d02f71b44f1f4e058a1d65882825fce6460de249b37d93b320
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB
DST/test/cnt_nga/0field.jpeg ADDED
DST/test/cnt_nga/0rahul-chakraborty-9Wg7qAhGmnU-unsplash.jpg ADDED

Git LFS Details

  • SHA256: df3391564d5d2d0267b29e8dc0f03a12fdfcb54caec2ad01eccf572961e991b6
  • Pointer size: 131 Bytes
  • Size of remote file: 625 kB
DST/test/cnt_nga/0trip.jpg ADDED

Git LFS Details

  • SHA256: de22ad71d01345c33d0c2b7f55fcdaae99af303ab1de3bb90edb86814997c25c
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
DST/test/cnt_nga/1mio-ito-DaGIjXNl5oA-unsplash.jpg ADDED

Git LFS Details

  • SHA256: d39954015331fe9e4d738145beed7679d432773da721303755212b5c60a4e327
  • Pointer size: 131 Bytes
  • Size of remote file: 672 kB
DST/test/sty/American Comic_Architecture_Church or mosque_7f69557d-751b-4c7c-9495-5156b2513989_42.png ADDED

Git LFS Details

  • SHA256: 43c85cdf0058036b395da0035727c5fda45b302d13abcc516a68ed9c53e63122
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
DST/test/sty/American Comic_Object_Backpack or bag_938b06c4-92bb-4535-a2cf-6112318d0c0d_42.png ADDED

Git LFS Details

  • SHA256: d8f6125e05532d0ff95b9c970de1ee85385d7eee51727f909cdf3284eb58b421
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
DST/test/sty/Anime_04c5405f-fcaa-4065-899e-49149e2835e7.png ADDED

Git LFS Details

  • SHA256: eb0cd470f2239111e30420f0855327fa244339f5a455df8148f91b6b1476750e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB
DST/test/sty/Anime_Animal_Dog_3d4f8063-1ed9-4601-8c16-6f9f2e77c81b_42.png ADDED

Git LFS Details

  • SHA256: f1e75634f6afcd7bf4f7c8aac134b148240136bb62e1e666c884c99c3c7b9990
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
DST/test/sty/Flat Design_Scene_Beach or coast_31be5a9d-5c07-4c60-be0c-6b0034e8244b_42.png ADDED

Git LFS Details

  • SHA256: 3b5fd093bb41acc78a990036ce71a990269f758e22810cddf95e5021fb38402b
  • Pointer size: 131 Bytes
  • Size of remote file: 823 kB
DST/test/sty/Flat Design_b5f391f3-c7f4-4c21-8f35-a6f33df8eae5.png ADDED

Git LFS Details

  • SHA256: e2f98ba755fb374d3679eb4d0e5382d8d4e19987562b00e541e18c3c0556dc9e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
DST/test/sty/Ghibli_19ed1a7f-d7ef-49f6-9f7d-9d5961c385cc.png ADDED

Git LFS Details

  • SHA256: 872899f743201010626636b0238fc97d3bae8221fc65db993bdb983655fc165c
  • Pointer size: 132 Bytes
  • Size of remote file: 2 MB
DST/test/sty/Graffiti_Scene_Forest scene_64c1b63c-d094-4f50-bf71-2ff3d0035dd8_42.png ADDED

Git LFS Details

  • SHA256: 876f1712f4434f3ca0dfe666012badc74edf2f0132a139f56c705618aa856c96
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB
DST/test/sty/Line Art_3f6a80ef-5fa8-492a-a8b6-1b515e3ebd97.png ADDED

Git LFS Details

  • SHA256: 91708209bda0a2e5d6b7eb898ca4fb995647d8a6e85e73d4cf736c984d07501e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.59 MB
DST/test/sty/Linocut_15502303-0459-4a02-be7e-b5b93ba69c1e.png ADDED

Git LFS Details

  • SHA256: 9c8c7ee41d8874bc5df5f0cce994246e3085fd286d39abe4fb5933509fc7d7f4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.61 MB
DST/test/sty/Neon_Scene_Beach or coast_c42c6871-6268-41c9-8ec7-a36c999b5acb_42.png ADDED

Git LFS Details

  • SHA256: 1abd8effa219b0b6997962bb33246f260cdf5595d2b206cc9161de697a825e5c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB