b3h-young123 commited on
Commit
71e893e
·
verified ·
1 Parent(s): 3a13adf

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. OmniGen/.gitattributes +38 -0
  2. OmniGen/.gitignore +2 -0
  3. OmniGen/LICENSE +21 -0
  4. OmniGen/OmniGen/__init__.py +4 -0
  5. OmniGen/OmniGen/__pycache__/__init__.cpython-310.pyc +0 -0
  6. OmniGen/OmniGen/__pycache__/model.cpython-310.pyc +0 -0
  7. OmniGen/OmniGen/__pycache__/pipeline.cpython-310.pyc +0 -0
  8. OmniGen/OmniGen/__pycache__/processor.cpython-310.pyc +0 -0
  9. OmniGen/OmniGen/__pycache__/scheduler.cpython-310.pyc +0 -0
  10. OmniGen/OmniGen/__pycache__/transformer.cpython-310.pyc +0 -0
  11. OmniGen/OmniGen/__pycache__/utils.cpython-310.pyc +0 -0
  12. OmniGen/OmniGen/model.py +406 -0
  13. OmniGen/OmniGen/pipeline.py +307 -0
  14. OmniGen/OmniGen/processor.py +338 -0
  15. OmniGen/OmniGen/scheduler.py +181 -0
  16. OmniGen/OmniGen/train.py +0 -0
  17. OmniGen/OmniGen/train_helper/__init__.py +2 -0
  18. OmniGen/OmniGen/train_helper/data.py +116 -0
  19. OmniGen/OmniGen/train_helper/loss.py +68 -0
  20. OmniGen/OmniGen/transformer.py +194 -0
  21. OmniGen/OmniGen/utils.py +110 -0
  22. OmniGen/README.md +20 -0
  23. OmniGen/app.py +408 -0
  24. OmniGen/docs/fine-tuning.md +172 -0
  25. OmniGen/docs/inference.md +167 -0
  26. OmniGen/imgs/.DS_Store +0 -0
  27. OmniGen/imgs/demo_cases.png +3 -0
  28. OmniGen/imgs/demo_cases/AI_Pioneers.jpg +3 -0
  29. OmniGen/imgs/demo_cases/edit.png +3 -0
  30. OmniGen/imgs/demo_cases/entity.png +3 -0
  31. OmniGen/imgs/demo_cases/reasoning.png +3 -0
  32. OmniGen/imgs/demo_cases/same_pose.png +3 -0
  33. OmniGen/imgs/demo_cases/skeletal.png +3 -0
  34. OmniGen/imgs/demo_cases/skeletal2img.png +3 -0
  35. OmniGen/imgs/demo_cases/t2i_woman_with_book.png +3 -0
  36. OmniGen/imgs/overall.jpg +3 -0
  37. OmniGen/imgs/referring.png +3 -0
  38. OmniGen/imgs/test_cases/1.jpg +3 -0
  39. OmniGen/imgs/test_cases/2.jpg +3 -0
  40. OmniGen/imgs/test_cases/3.jpg +3 -0
  41. OmniGen/imgs/test_cases/4.jpg +3 -0
  42. OmniGen/imgs/test_cases/Amanda.jpg +3 -0
  43. OmniGen/imgs/test_cases/cat.jpeg +3 -0
  44. OmniGen/imgs/test_cases/control.jpg +3 -0
  45. OmniGen/imgs/test_cases/guitar1.png +3 -0
  46. OmniGen/imgs/test_cases/icl1.jpg +3 -0
  47. OmniGen/imgs/test_cases/icl2.jpg +3 -0
  48. OmniGen/imgs/test_cases/icl3.jpg +3 -0
  49. OmniGen/imgs/test_cases/img1.jpg +3 -0
  50. OmniGen/imgs/test_cases/img2.jpg +3 -0
OmniGen/.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ imgs/** filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.png filter=lfs diff=lfs merge=lfs -text
OmniGen/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.ipynb
2
+ **/__pycache__/
OmniGen/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 VectorSpaceLab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
OmniGen/OmniGen/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .model import OmniGen
2
+ from .processor import OmniGenProcessor
3
+ from .scheduler import OmniGenScheduler
4
+ from .pipeline import OmniGenPipeline
OmniGen/OmniGen/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (344 Bytes). View file
 
OmniGen/OmniGen/__pycache__/model.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
OmniGen/OmniGen/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (8.64 kB). View file
 
OmniGen/OmniGen/__pycache__/processor.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
OmniGen/OmniGen/__pycache__/scheduler.cpython-310.pyc ADDED
Binary file (2.77 kB). View file
 
OmniGen/OmniGen/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (3.97 kB). View file
 
OmniGen/OmniGen/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.54 kB). View file
 
OmniGen/OmniGen/model.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The code is revised from DiT
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import math
7
+ from typing import Dict
8
+
9
+ from diffusers.loaders import PeftAdapterMixin
10
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
11
+ from huggingface_hub import snapshot_download
12
+ from safetensors.torch import load_file
13
+
14
+ from OmniGen.transformer import Phi3Config, Phi3Transformer
15
+
16
+
17
+ def modulate(x, shift, scale):
18
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
19
+
20
+
21
+ class TimestepEmbedder(nn.Module):
22
+ """
23
+ Embeds scalar timesteps into vector representations.
24
+ """
25
+ def __init__(self, hidden_size, frequency_embedding_size=256):
26
+ super().__init__()
27
+ self.mlp = nn.Sequential(
28
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
29
+ nn.SiLU(),
30
+ nn.Linear(hidden_size, hidden_size, bias=True),
31
+ )
32
+ self.frequency_embedding_size = frequency_embedding_size
33
+
34
+ @staticmethod
35
+ def timestep_embedding(t, dim, max_period=10000):
36
+ """
37
+ Create sinusoidal timestep embeddings.
38
+ :param t: a 1-D Tensor of N indices, one per batch element.
39
+ These may be fractional.
40
+ :param dim: the dimension of the output.
41
+ :param max_period: controls the minimum frequency of the embeddings.
42
+ :return: an (N, D) Tensor of positional embeddings.
43
+ """
44
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
45
+ half = dim // 2
46
+ freqs = torch.exp(
47
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
48
+ ).to(device=t.device)
49
+ args = t[:, None].float() * freqs[None]
50
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
51
+ if dim % 2:
52
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
53
+ return embedding
54
+
55
+ def forward(self, t, dtype=torch.float32):
56
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
57
+ t_emb = self.mlp(t_freq)
58
+ return t_emb
59
+
60
+
61
+ class FinalLayer(nn.Module):
62
+ """
63
+ The final layer of DiT.
64
+ """
65
+ def __init__(self, hidden_size, patch_size, out_channels):
66
+ super().__init__()
67
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
68
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
69
+ self.adaLN_modulation = nn.Sequential(
70
+ nn.SiLU(),
71
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
72
+ )
73
+
74
+ def forward(self, x, c):
75
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
76
+ x = modulate(self.norm_final(x), shift, scale)
77
+ x = self.linear(x)
78
+ return x
79
+
80
+
81
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
82
+ """
83
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
84
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
85
+ """
86
+ if isinstance(grid_size, int):
87
+ grid_size = (grid_size, grid_size)
88
+
89
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
90
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
91
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
92
+ grid = np.stack(grid, axis=0)
93
+
94
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
95
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
96
+ if cls_token and extra_tokens > 0:
97
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
98
+ return pos_embed
99
+
100
+
101
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
102
+ assert embed_dim % 2 == 0
103
+
104
+ # use half of dimensions to encode grid_h
105
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
106
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
107
+
108
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
109
+ return emb
110
+
111
+
112
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
113
+ """
114
+ embed_dim: output dimension for each position
115
+ pos: a list of positions to be encoded: size (M,)
116
+ out: (M, D)
117
+ """
118
+ assert embed_dim % 2 == 0
119
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
120
+ omega /= embed_dim / 2.
121
+ omega = 1. / 10000**omega # (D/2,)
122
+
123
+ pos = pos.reshape(-1) # (M,)
124
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
125
+
126
+ emb_sin = np.sin(out) # (M, D/2)
127
+ emb_cos = np.cos(out) # (M, D/2)
128
+
129
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
130
+ return emb
131
+
132
+
133
+ class PatchEmbedMR(nn.Module):
134
+ """ 2D Image to Patch Embedding
135
+ """
136
+ def __init__(
137
+ self,
138
+ patch_size: int = 2,
139
+ in_chans: int = 4,
140
+ embed_dim: int = 768,
141
+ bias: bool = True,
142
+ ):
143
+ super().__init__()
144
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
145
+
146
+ def forward(self, x):
147
+ x = self.proj(x)
148
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
149
+ return x
150
+
151
+
152
+ class OmniGen(nn.Module, PeftAdapterMixin):
153
+ """
154
+ Diffusion model with a Transformer backbone.
155
+ """
156
+ def __init__(
157
+ self,
158
+ transformer_config: Phi3Config,
159
+ patch_size=2,
160
+ in_channels=4,
161
+ pe_interpolation: float = 1.0,
162
+ pos_embed_max_size: int = 192,
163
+ ):
164
+ super().__init__()
165
+ self.in_channels = in_channels
166
+ self.out_channels = in_channels
167
+ self.patch_size = patch_size
168
+ self.pos_embed_max_size = pos_embed_max_size
169
+
170
+ hidden_size = transformer_config.hidden_size
171
+
172
+ self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
173
+ self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
174
+
175
+ self.time_token = TimestepEmbedder(hidden_size)
176
+ self.t_embedder = TimestepEmbedder(hidden_size)
177
+
178
+ self.pe_interpolation = pe_interpolation
179
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
180
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
181
+
182
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
183
+
184
+ self.initialize_weights()
185
+
186
+ self.llm = Phi3Transformer(config=transformer_config)
187
+ self.llm.config.use_cache = False
188
+
189
+ @classmethod
190
+ def from_pretrained(cls, model_name):
191
+ if not os.path.exists(model_name):
192
+ cache_folder = os.getenv('HF_HUB_CACHE')
193
+ model_name = snapshot_download(repo_id=model_name,
194
+ cache_dir=cache_folder,
195
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
196
+ config = Phi3Config.from_pretrained(model_name)
197
+ model = cls(config)
198
+ if os.path.exists(os.path.join(model_name, 'model.safetensors')):
199
+ print("Loading safetensors")
200
+ ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
201
+ else:
202
+ ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
203
+ model.load_state_dict(ckpt)
204
+ return model
205
+
206
+ def initialize_weights(self):
207
+ assert not hasattr(self, "llama")
208
+
209
+ # Initialize transformer layers:
210
+ def _basic_init(module):
211
+ if isinstance(module, nn.Linear):
212
+ torch.nn.init.xavier_uniform_(module.weight)
213
+ if module.bias is not None:
214
+ nn.init.constant_(module.bias, 0)
215
+ self.apply(_basic_init)
216
+
217
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
218
+ w = self.x_embedder.proj.weight.data
219
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
220
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
221
+
222
+ w = self.input_x_embedder.proj.weight.data
223
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
224
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
225
+
226
+
227
+ # Initialize timestep embedding MLP:
228
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
229
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
230
+ nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
231
+ nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
232
+
233
+ # Zero-out output layers:
234
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
235
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
236
+ nn.init.constant_(self.final_layer.linear.weight, 0)
237
+ nn.init.constant_(self.final_layer.linear.bias, 0)
238
+
239
+ def unpatchify(self, x, h, w):
240
+ """
241
+ x: (N, T, patch_size**2 * C)
242
+ imgs: (N, H, W, C)
243
+ """
244
+ c = self.out_channels
245
+
246
+ x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
247
+ x = torch.einsum('nhwpqc->nchpwq', x)
248
+ imgs = x.reshape(shape=(x.shape[0], c, h, w))
249
+ return imgs
250
+
251
+
252
+ def cropped_pos_embed(self, height, width):
253
+ """Crops positional embeddings for SD3 compatibility."""
254
+ if self.pos_embed_max_size is None:
255
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
256
+
257
+ height = height // self.patch_size
258
+ width = width // self.patch_size
259
+ if height > self.pos_embed_max_size:
260
+ raise ValueError(
261
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
262
+ )
263
+ if width > self.pos_embed_max_size:
264
+ raise ValueError(
265
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
266
+ )
267
+
268
+ top = (self.pos_embed_max_size - height) // 2
269
+ left = (self.pos_embed_max_size - width) // 2
270
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
271
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
272
+ # print(top, top + height, left, left + width, spatial_pos_embed.size())
273
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
274
+ return spatial_pos_embed
275
+
276
+
277
+ def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
278
+ if isinstance(latents, list):
279
+ return_list = False
280
+ if padding_latent is None:
281
+ padding_latent = [None] * len(latents)
282
+ return_list = True
283
+ patched_latents, num_tokens, shapes = [], [], []
284
+ for latent, padding in zip(latents, padding_latent):
285
+ height, width = latent.shape[-2:]
286
+ if is_input_images:
287
+ latent = self.input_x_embedder(latent)
288
+ else:
289
+ latent = self.x_embedder(latent)
290
+ pos_embed = self.cropped_pos_embed(height, width)
291
+ latent = latent + pos_embed
292
+ if padding is not None:
293
+ latent = torch.cat([latent, padding], dim=-2)
294
+ patched_latents.append(latent)
295
+
296
+ num_tokens.append(pos_embed.size(1))
297
+ shapes.append([height, width])
298
+ if not return_list:
299
+ latents = torch.cat(patched_latents, dim=0)
300
+ else:
301
+ latents = patched_latents
302
+ else:
303
+ height, width = latents.shape[-2:]
304
+ if is_input_images:
305
+ latents = self.input_x_embedder(latents)
306
+ else:
307
+ latents = self.x_embedder(latents)
308
+ pos_embed = self.cropped_pos_embed(height, width)
309
+ latents = latents + pos_embed
310
+ num_tokens = latents.size(1)
311
+ shapes = [height, width]
312
+ return latents, num_tokens, shapes
313
+
314
+
315
+ def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
316
+ """
317
+
318
+ """
319
+ input_is_list = isinstance(x, list)
320
+ x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
321
+ time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
322
+
323
+ if input_img_latents is not None:
324
+ input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
325
+ if input_ids is not None:
326
+ condition_embeds = self.llm.embed_tokens(input_ids).clone()
327
+ input_img_inx = 0
328
+ for b_inx in input_image_sizes.keys():
329
+ for start_inx, end_inx in input_image_sizes[b_inx]:
330
+ condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
331
+ input_img_inx += 1
332
+ if input_img_latents is not None:
333
+ assert input_img_inx == len(input_latents)
334
+
335
+ input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
336
+ else:
337
+ input_emb = torch.cat([time_token, x], dim=1)
338
+ output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
339
+ output, past_key_values = output.last_hidden_state, output.past_key_values
340
+ if input_is_list:
341
+ image_embedding = output[:, -max(num_tokens):]
342
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
343
+ x = self.final_layer(image_embedding, time_emb)
344
+ latents = []
345
+ for i in range(x.size(0)):
346
+ latent = x[i:i+1, :num_tokens[i]]
347
+ latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
348
+ latents.append(latent)
349
+ else:
350
+ image_embedding = output[:, -num_tokens:]
351
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
352
+ x = self.final_layer(image_embedding, time_emb)
353
+ latents = self.unpatchify(x, shapes[0], shapes[1])
354
+
355
+ if return_past_key_values:
356
+ return latents, past_key_values
357
+ return latents
358
+
359
+ @torch.no_grad()
360
+ def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
361
+ self.llm.config.use_cache = use_kv_cache
362
+ model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
363
+ if use_img_cfg:
364
+ cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
365
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
366
+ model_out = [cond, cond, cond]
367
+ else:
368
+ cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
369
+ cond = uncond + cfg_scale * (cond - uncond)
370
+ model_out = [cond, cond]
371
+
372
+ return torch.cat(model_out, dim=0), past_key_values
373
+
374
+
375
+ @torch.no_grad()
376
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
377
+ self.llm.config.use_cache = use_kv_cache
378
+ if past_key_values is None:
379
+ past_key_values = [None] * len(attention_mask)
380
+
381
+ x = torch.split(x, len(x) // len(attention_mask), dim=0)
382
+ timestep = timestep.to(x[0].dtype)
383
+ timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
384
+
385
+ model_out, pask_key_values = [], []
386
+ for i in range(len(input_ids)):
387
+ temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
388
+ model_out.append(temp_out)
389
+ pask_key_values.append(temp_pask_key_values)
390
+
391
+ if len(model_out) == 3:
392
+ cond, uncond, img_cond = model_out
393
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
394
+ model_out = [cond, cond, cond]
395
+ elif len(model_out) == 2:
396
+ cond, uncond = model_out
397
+ cond = uncond + cfg_scale * (cond - uncond)
398
+ model_out = [cond, cond]
399
+ else:
400
+ return model_out[0]
401
+
402
+ return torch.cat(model_out, dim=0), pask_key_values
403
+
404
+
405
+
406
+
OmniGen/OmniGen/pipeline.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import inspect
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+ import gc
5
+
6
+ from PIL import Image
7
+ import numpy as np
8
+ import torch
9
+ from huggingface_hub import snapshot_download
10
+ from peft import LoraConfig, PeftModel
11
+ from diffusers.models import AutoencoderKL
12
+ from diffusers.utils import (
13
+ USE_PEFT_BACKEND,
14
+ is_torch_xla_available,
15
+ logging,
16
+ replace_example_docstring,
17
+ scale_lora_layers,
18
+ unscale_lora_layers,
19
+ )
20
+ from safetensors.torch import load_file
21
+
22
+ from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ EXAMPLE_DOC_STRING = """
28
+ Examples:
29
+ ```py
30
+ >>> from OmniGen import OmniGenPipeline
31
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
32
+ ... base_model
33
+ ... )
34
+ >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
35
+ >>> image = pipe(
36
+ ... prompt,
37
+ ... guidance_scale=2.5,
38
+ ... num_inference_steps=50,
39
+ ... ).images[0]
40
+ >>> image.save("t2i.png")
41
+ ```
42
+ """
43
+
44
+
45
+ 90
46
+ class OmniGenPipeline:
47
+ def __init__(
48
+ self,
49
+ vae: AutoencoderKL,
50
+ model: OmniGen,
51
+ processor: OmniGenProcessor,
52
+ ):
53
+ self.vae = vae
54
+ self.model = model
55
+ self.processor = processor
56
+
57
+ if torch.cuda.is_available():
58
+ self.device = torch.device("cuda")
59
+ elif torch.backends.mps.is_available():
60
+ self.device = torch.device("mps")
61
+ else:
62
+ logger.info("Don't detect any available devices, using CPU instead")
63
+ self.device = torch.device("cpu")
64
+
65
+ self.model.to(torch.bfloat16)
66
+ self.model.eval()
67
+ self.vae.eval()
68
+
69
+ self.model_cpu_offload = False
70
+
71
+ @classmethod
72
+ def from_pretrained(cls, model_name, vae_path: str=None):
73
+ if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
74
+ logger.info("Model not found, downloading...")
75
+ cache_folder = os.getenv('HF_HUB_CACHE')
76
+ model_name = snapshot_download(repo_id=model_name,
77
+ cache_dir=cache_folder,
78
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
79
+ logger.info(f"Downloaded model to {model_name}")
80
+ model = OmniGen.from_pretrained(model_name)
81
+ processor = OmniGenProcessor.from_pretrained(model_name)
82
+
83
+ if os.path.exists(os.path.join(model_name, "vae")):
84
+ vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
85
+ elif vae_path is not None:
86
+ vae = AutoencoderKL.from_pretrained(vae_path).to(device)
87
+ else:
88
+ logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
89
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
90
+
91
+ return cls(vae, model, processor)
92
+
93
+ def merge_lora(self, lora_path: str):
94
+ model = PeftModel.from_pretrained(self.model, lora_path)
95
+ model.merge_and_unload()
96
+
97
+ self.model = model
98
+
99
+ def to(self, device: Union[str, torch.device]):
100
+ if isinstance(device, str):
101
+ device = torch.device(device)
102
+ self.model.to(device)
103
+ self.vae.to(device)
104
+ self.device = device
105
+
106
+ def vae_encode(self, x, dtype):
107
+ if self.vae.config.shift_factor is not None:
108
+ x = self.vae.encode(x).latent_dist.sample()
109
+ x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
110
+ else:
111
+ x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
112
+ x = x.to(dtype)
113
+ return x
114
+
115
+ def move_to_device(self, data):
116
+ if isinstance(data, list):
117
+ return [x.to(self.device) for x in data]
118
+ return data.to(self.device)
119
+
120
+ def enable_model_cpu_offload(self):
121
+ self.model_cpu_offload = True
122
+ self.model.to("cpu")
123
+ self.vae.to("cpu")
124
+ torch.cuda.empty_cache() # Clear VRAM
125
+ gc.collect() # Run garbage collection to free system RAM
126
+
127
+ def disable_model_cpu_offload(self):
128
+ self.model_cpu_offload = False
129
+ self.model.to(self.device)
130
+ self.vae.to(self.device)
131
+
132
+ @torch.no_grad()
133
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
134
+ def __call__(
135
+ self,
136
+ prompt: Union[str, List[str]],
137
+ input_images: Union[List[str], List[List[str]]] = None,
138
+ height: int = 1024,
139
+ width: int = 1024,
140
+ num_inference_steps: int = 50,
141
+ guidance_scale: float = 3,
142
+ use_img_guidance: bool = True,
143
+ img_guidance_scale: float = 1.6,
144
+ max_input_image_size: int = 1024,
145
+ separate_cfg_infer: bool = True,
146
+ offload_model: bool = False,
147
+ use_kv_cache: bool = True,
148
+ offload_kv_cache: bool = True,
149
+ use_input_image_size_as_output: bool = False,
150
+ dtype: torch.dtype = torch.bfloat16,
151
+ seed: int = None,
152
+ ):
153
+ r"""
154
+ Function invoked when calling the pipeline for generation.
155
+
156
+ Args:
157
+ prompt (`str` or `List[str]`):
158
+ The prompt or prompts to guide the image generation.
159
+ input_images (`List[str]` or `List[List[str]]`, *optional*):
160
+ The list of input images. We will replace the "<|image_i|>" in prompt with the 1-th image in list.
161
+ height (`int`, *optional*, defaults to 1024):
162
+ The height in pixels of the generated image. The number must be a multiple of 16.
163
+ width (`int`, *optional*, defaults to 1024):
164
+ The width in pixels of the generated image. The number must be a multiple of 16.
165
+ num_inference_steps (`int`, *optional*, defaults to 50):
166
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
167
+ guidance_scale (`float`, *optional*, defaults to 4.0):
168
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
169
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
170
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
171
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
172
+ usually at the expense of lower image quality.
173
+ use_img_guidance (`bool`, *optional*, defaults to True):
174
+ Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
175
+ img_guidance_scale (`float`, *optional*, defaults to 1.6):
176
+ Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
177
+ max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of input image, which will be used to crop the input image to the maximum size
178
+ separate_cfg_infer (`bool`, *optional*, defaults to False):
179
+ Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
180
+ use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
181
+ offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation silightly
182
+ offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
183
+ use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
184
+ seed (`int`, *optional*):
185
+ A random seed for generating output.
186
+ dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
187
+ data type for the model
188
+ Examples:
189
+
190
+ Returns:
191
+ A list with the generated images.
192
+ """
193
+ # check inputs:
194
+ if use_input_image_size_as_output:
195
+ assert isinstance(prompt, str) and len(input_images) == 1, "if you want to make sure the output image have the same size as the input image, please only input one image instead of multiple input images"
196
+ else:
197
+ assert height%16 == 0 and width%16 == 0, "The height and width must be a multiple of 16."
198
+ if input_images is None:
199
+ use_img_guidance = False
200
+ if isinstance(prompt, str):
201
+ prompt = [prompt]
202
+ input_images = [input_images] if input_images is not None else None
203
+
204
+ # set model and processor
205
+ if max_input_image_size != self.processor.max_image_size:
206
+ self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size)
207
+ if offload_model:
208
+ self.enable_model_cpu_offload()
209
+ else:
210
+ self.disable_model_cpu_offload()
211
+
212
+ input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer, use_input_image_size_as_output=use_input_image_size_as_output)
213
+
214
+ num_prompt = len(prompt)
215
+ num_cfg = 2 if use_img_guidance else 1
216
+ if use_input_image_size_as_output:
217
+ if separate_cfg_infer:
218
+ height, width = input_data['input_pixel_values'][0][0].shape[-2:]
219
+ else:
220
+ height, width = input_data['input_pixel_values'][0].shape[-2:]
221
+ latent_size_h, latent_size_w = height//8, width//8
222
+
223
+ if seed is not None:
224
+ generator = torch.Generator(device=self.device).manual_seed(seed)
225
+ else:
226
+ generator = None
227
+ latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
228
+ latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
229
+
230
+ if input_images is not None and self.model_cpu_offload: self.vae.to(self.device)
231
+ input_img_latents = []
232
+ if separate_cfg_infer:
233
+ for temp_pixel_values in input_data['input_pixel_values']:
234
+ temp_input_latents = []
235
+ for img in temp_pixel_values:
236
+ img = self.vae_encode(img.to(self.device), dtype)
237
+ temp_input_latents.append(img)
238
+ input_img_latents.append(temp_input_latents)
239
+ else:
240
+ for img in input_data['input_pixel_values']:
241
+ img = self.vae_encode(img.to(self.device), dtype)
242
+ input_img_latents.append(img)
243
+ if input_images is not None and self.model_cpu_offload:
244
+ self.vae.to('cpu')
245
+ torch.cuda.empty_cache() # Clear VRAM
246
+ gc.collect() # Run garbage collection to free system RAM
247
+
248
+ model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
249
+ input_img_latents=input_img_latents,
250
+ input_image_sizes=input_data['input_image_sizes'],
251
+ attention_mask=self.move_to_device(input_data["attention_mask"]),
252
+ position_ids=self.move_to_device(input_data["position_ids"]),
253
+ cfg_scale=guidance_scale,
254
+ img_cfg_scale=img_guidance_scale,
255
+ use_img_cfg=use_img_guidance,
256
+ use_kv_cache=use_kv_cache,
257
+ offload_model=offload_model,
258
+ )
259
+
260
+ if separate_cfg_infer:
261
+ func = self.model.forward_with_separate_cfg
262
+ else:
263
+ func = self.model.forward_with_cfg
264
+ self.model.to(dtype)
265
+
266
+ if self.model_cpu_offload:
267
+ for name, param in self.model.named_parameters():
268
+ if 'layers' in name and 'layers.0' not in name:
269
+ param.data = param.data.cpu()
270
+ else:
271
+ param.data = param.data.to(self.device)
272
+ for buffer_name, buffer in self.model.named_buffers():
273
+ setattr(self.model, buffer_name, buffer.to(self.device))
274
+ # else:
275
+ # self.model.to(self.device)
276
+
277
+ scheduler = OmniGenScheduler(num_steps=num_inference_steps)
278
+ samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache)
279
+ samples = samples.chunk((1+num_cfg), dim=0)[0]
280
+
281
+ if self.model_cpu_offload:
282
+ self.model.to('cpu')
283
+ torch.cuda.empty_cache()
284
+ gc.collect()
285
+
286
+ self.vae.to(self.device)
287
+ samples = samples.to(torch.float32)
288
+ if self.vae.config.shift_factor is not None:
289
+ samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
290
+ else:
291
+ samples = samples / self.vae.config.scaling_factor
292
+ samples = self.vae.decode(samples).sample
293
+
294
+ if self.model_cpu_offload:
295
+ self.vae.to('cpu')
296
+ torch.cuda.empty_cache()
297
+ gc.collect()
298
+
299
+ output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
300
+ output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
301
+ output_images = []
302
+ for i, sample in enumerate(output_samples):
303
+ output_images.append(Image.fromarray(sample))
304
+
305
+ torch.cuda.empty_cache() # Clear VRAM
306
+ gc.collect() # Run garbage collection to free system RAM
307
+ return output_images
OmniGen/OmniGen/processor.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import Dict, List
4
+ import json
5
+
6
+ import torch
7
+ import numpy as np
8
+ import random
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from transformers import AutoTokenizer
12
+ from huggingface_hub import snapshot_download
13
+
14
+ from OmniGen.utils import (
15
+ create_logger,
16
+ update_ema,
17
+ requires_grad,
18
+ center_crop_arr,
19
+ crop_arr,
20
+ )
21
+
22
+
23
+
24
+
25
+ class OmniGenProcessor:
26
+ def __init__(self,
27
+ text_tokenizer,
28
+ max_image_size: int=1024):
29
+ self.text_tokenizer = text_tokenizer
30
+ self.max_image_size = max_image_size
31
+
32
+ self.image_transform = transforms.Compose([
33
+ transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
36
+ ])
37
+
38
+ self.collator = OmniGenCollator()
39
+ self.separate_collator = OmniGenSeparateCollator()
40
+
41
+ @classmethod
42
+ def from_pretrained(cls, model_name):
43
+ if not os.path.exists(model_name):
44
+ cache_folder = os.getenv('HF_HUB_CACHE')
45
+ model_name = snapshot_download(repo_id=model_name,
46
+ cache_dir=cache_folder,
47
+ allow_patterns="*.json")
48
+ text_tokenizer = AutoTokenizer.from_pretrained(model_name)
49
+
50
+ return cls(text_tokenizer)
51
+
52
+
53
+ def process_image(self, image):
54
+ image = Image.open(image).convert('RGB')
55
+ return self.image_transform(image)
56
+
57
+ def process_multi_modal_prompt(self, text, input_images):
58
+ text = self.add_prefix_instruction(text)
59
+ if input_images is None or len(input_images) == 0:
60
+ model_inputs = self.text_tokenizer(text)
61
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
62
+
63
+ pattern = r"<\|image_\d+\|>"
64
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
65
+
66
+ for i in range(1, len(prompt_chunks)):
67
+ if prompt_chunks[i][0] == 1:
68
+ prompt_chunks[i] = prompt_chunks[i][1:]
69
+
70
+ image_tags = re.findall(pattern, text)
71
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
72
+
73
+ unique_image_ids = sorted(list(set(image_ids)))
74
+ assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
75
+ # total images must be the same as the number of image tags
76
+ assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
77
+
78
+ input_images = [input_images[x-1] for x in image_ids]
79
+
80
+ all_input_ids = []
81
+ img_inx = []
82
+ idx = 0
83
+ for i in range(len(prompt_chunks)):
84
+ all_input_ids.extend(prompt_chunks[i])
85
+ if i != len(prompt_chunks) -1:
86
+ start_inx = len(all_input_ids)
87
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
88
+ img_inx.append([start_inx, start_inx+size])
89
+ all_input_ids.extend([0]*size)
90
+
91
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
92
+
93
+
94
+ def add_prefix_instruction(self, prompt):
95
+ user_prompt = '<|user|>\n'
96
+ generation_prompt = 'Generate an image according to the following instructions\n'
97
+ assistant_prompt = '<|assistant|>\n<|diffusion|>'
98
+ prompt_suffix = "<|end|>\n"
99
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
100
+ return prompt
101
+
102
+
103
+ def __call__(self,
104
+ instructions: List[str],
105
+ input_images: List[List[str]] = None,
106
+ height: int = 1024,
107
+ width: int = 1024,
108
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
109
+ use_img_cfg: bool = True,
110
+ separate_cfg_input: bool = False,
111
+ use_input_image_size_as_output: bool=False,
112
+ ) -> Dict:
113
+
114
+ if input_images is None:
115
+ use_img_cfg = False
116
+ if isinstance(instructions, str):
117
+ instructions = [instructions]
118
+ input_images = [input_images]
119
+
120
+ input_data = []
121
+ for i in range(len(instructions)):
122
+ cur_instruction = instructions[i]
123
+ cur_input_images = None if input_images is None else input_images[i]
124
+ if cur_input_images is not None and len(cur_input_images) > 0:
125
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
126
+ else:
127
+ cur_input_images = None
128
+ assert "<img><|image_1|></img>" not in cur_instruction
129
+
130
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
131
+
132
+
133
+ neg_mllm_input, img_cfg_mllm_input = None, None
134
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
135
+ if use_img_cfg:
136
+ if cur_input_images is not None and len(cur_input_images) >= 1:
137
+ img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
138
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
139
+ else:
140
+ img_cfg_mllm_input = neg_mllm_input
141
+
142
+ if use_input_image_size_as_output:
143
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
144
+ else:
145
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
146
+
147
+ if separate_cfg_input:
148
+ return self.separate_collator(input_data)
149
+ return self.collator(input_data)
150
+
151
+
152
+
153
+
154
+ class OmniGenCollator:
155
+ def __init__(self, pad_token_id=2, hidden_size=3072):
156
+ self.pad_token_id = pad_token_id
157
+ self.hidden_size = hidden_size
158
+
159
+ def create_position(self, attention_mask, num_tokens_for_output_images):
160
+ position_ids = []
161
+ text_length = attention_mask.size(-1)
162
+ img_length = max(num_tokens_for_output_images)
163
+ for mask in attention_mask:
164
+ temp_l = torch.sum(mask)
165
+ temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
166
+ position_ids.append(temp_position)
167
+ return torch.LongTensor(position_ids)
168
+
169
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
170
+ extended_mask = []
171
+ padding_images = []
172
+ text_length = attention_mask.size(-1)
173
+ img_length = max(num_tokens_for_output_images)
174
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
175
+ inx = 0
176
+ for mask in attention_mask:
177
+ temp_l = torch.sum(mask)
178
+ pad_l = text_length - temp_l
179
+
180
+ temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
181
+
182
+ image_mask = torch.zeros(size=(temp_l+1, img_length))
183
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
184
+
185
+ image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
186
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
187
+
188
+ if pad_l > 0:
189
+ pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
190
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
191
+
192
+ pad_mask = torch.ones(size=(pad_l, seq_len))
193
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
194
+
195
+ true_img_length = num_tokens_for_output_images[inx]
196
+ pad_img_length = img_length - true_img_length
197
+ if pad_img_length > 0:
198
+ temp_mask[:, -pad_img_length:] = 0
199
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
200
+ else:
201
+ temp_padding_imgs = None
202
+
203
+ extended_mask.append(temp_mask.unsqueeze(0))
204
+ padding_images.append(temp_padding_imgs)
205
+ inx += 1
206
+ return torch.cat(extended_mask, dim=0), padding_images
207
+
208
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
209
+ for b_inx in image_sizes.keys():
210
+ for start_inx, end_inx in image_sizes[b_inx]:
211
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
212
+
213
+ return attention_mask
214
+
215
+ def pad_input_ids(self, input_ids, image_sizes):
216
+ max_l = max([len(x) for x in input_ids])
217
+ padded_ids = []
218
+ attention_mask = []
219
+ new_image_sizes = []
220
+
221
+ for i in range(len(input_ids)):
222
+ temp_ids = input_ids[i]
223
+ temp_l = len(temp_ids)
224
+ pad_l = max_l - temp_l
225
+ if pad_l == 0:
226
+ attention_mask.append([1]*max_l)
227
+ padded_ids.append(temp_ids)
228
+ else:
229
+ attention_mask.append([0]*pad_l+[1]*temp_l)
230
+ padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
231
+
232
+ if i in image_sizes:
233
+ new_inx = []
234
+ for old_inx in image_sizes[i]:
235
+ new_inx.append([x+pad_l for x in old_inx])
236
+ image_sizes[i] = new_inx
237
+
238
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
239
+
240
+
241
+ def process_mllm_input(self, mllm_inputs, target_img_size):
242
+ num_tokens_for_output_images = []
243
+ for img_size in target_img_size:
244
+ num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
245
+
246
+ pixel_values, image_sizes = [], {}
247
+ b_inx = 0
248
+ for x in mllm_inputs:
249
+ if x['pixel_values'] is not None:
250
+ pixel_values.extend(x['pixel_values'])
251
+ for size in x['image_sizes']:
252
+ if b_inx not in image_sizes:
253
+ image_sizes[b_inx] = [size]
254
+ else:
255
+ image_sizes[b_inx].append(size)
256
+ b_inx += 1
257
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
258
+
259
+
260
+ input_ids = [x['input_ids'] for x in mllm_inputs]
261
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
262
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
263
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
264
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
265
+
266
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
267
+
268
+
269
+ def __call__(self, features):
270
+ mllm_inputs = [f[0] for f in features]
271
+ cfg_mllm_inputs = [f[1] for f in features]
272
+ img_cfg_mllm_input = [f[2] for f in features]
273
+ target_img_size = [f[3] for f in features]
274
+
275
+
276
+ if img_cfg_mllm_input[0] is not None:
277
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
278
+ target_img_size = target_img_size + target_img_size + target_img_size
279
+ else:
280
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
281
+ target_img_size = target_img_size + target_img_size
282
+
283
+
284
+ all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
285
+
286
+ data = {"input_ids": all_padded_input_ids,
287
+ "attention_mask": all_attention_mask,
288
+ "position_ids": all_position_ids,
289
+ "input_pixel_values": all_pixel_values,
290
+ "input_image_sizes": all_image_sizes,
291
+ "padding_images": all_padding_images,
292
+ }
293
+ return data
294
+
295
+
296
+ class OmniGenSeparateCollator(OmniGenCollator):
297
+ def __call__(self, features):
298
+ mllm_inputs = [f[0] for f in features]
299
+ cfg_mllm_inputs = [f[1] for f in features]
300
+ img_cfg_mllm_input = [f[2] for f in features]
301
+ target_img_size = [f[3] for f in features]
302
+
303
+ all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
304
+
305
+
306
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
307
+ all_padded_input_ids.append(padded_input_ids)
308
+ all_attention_mask.append(attention_mask)
309
+ all_position_ids.append(position_ids)
310
+ all_pixel_values.append(pixel_values)
311
+ all_image_sizes.append(image_sizes)
312
+ all_padding_images.append(padding_images)
313
+
314
+ if cfg_mllm_inputs[0] is not None:
315
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
316
+ all_padded_input_ids.append(padded_input_ids)
317
+ all_attention_mask.append(attention_mask)
318
+ all_position_ids.append(position_ids)
319
+ all_pixel_values.append(pixel_values)
320
+ all_image_sizes.append(image_sizes)
321
+ all_padding_images.append(padding_images)
322
+ if img_cfg_mllm_input[0] is not None:
323
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
324
+ all_padded_input_ids.append(padded_input_ids)
325
+ all_attention_mask.append(attention_mask)
326
+ all_position_ids.append(position_ids)
327
+ all_pixel_values.append(pixel_values)
328
+ all_image_sizes.append(image_sizes)
329
+ all_padding_images.append(padding_images)
330
+
331
+ data = {"input_ids": all_padded_input_ids,
332
+ "attention_mask": all_attention_mask,
333
+ "position_ids": all_position_ids,
334
+ "input_pixel_values": all_pixel_values,
335
+ "input_image_sizes": all_image_sizes,
336
+ "padding_images": all_padding_images,
337
+ }
338
+ return data
OmniGen/OmniGen/scheduler.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from typing import Optional, Dict, Any, Tuple, List
3
+ import gc
4
+
5
+ import torch
6
+ from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
7
+
8
+
9
+
10
+ class OmniGenCache(DynamicCache):
11
+ def __init__(self,
12
+ num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
13
+ if not torch.cuda.is_available():
14
+ raise RuntimeError("OffloadedCache can only be used with a GPU")
15
+ super().__init__()
16
+ self.original_device = []
17
+ self.prefetch_stream = torch.cuda.Stream()
18
+ self.num_tokens_for_img = num_tokens_for_img
19
+ self.offload_kv_cache = offload_kv_cache
20
+
21
+ def prefetch_layer(self, layer_idx: int):
22
+ "Starts prefetching the next layer cache"
23
+ if layer_idx < len(self):
24
+ with torch.cuda.stream(self.prefetch_stream):
25
+ # Prefetch next layer tensors to GPU
26
+ device = self.original_device[layer_idx]
27
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
28
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
29
+
30
+
31
+ def evict_previous_layer(self, layer_idx: int):
32
+ "Moves the previous layer cache to the CPU"
33
+ if len(self) > 2:
34
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
35
+ if layer_idx == 0:
36
+ prev_layer_idx = -1
37
+ else:
38
+ prev_layer_idx = (layer_idx - 1) % len(self)
39
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
40
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
41
+
42
+
43
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
44
+ "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
45
+ if layer_idx < len(self):
46
+ if self.offload_kv_cache:
47
+ # Evict the previous layer if necessary
48
+ torch.cuda.current_stream().synchronize()
49
+ self.evict_previous_layer(layer_idx)
50
+ # Load current layer cache to its original device if not already there
51
+ original_device = self.original_device[layer_idx]
52
+ # self.prefetch_stream.synchronize(original_device)
53
+ torch.cuda.synchronize(self.prefetch_stream)
54
+ key_tensor = self.key_cache[layer_idx]
55
+ value_tensor = self.value_cache[layer_idx]
56
+
57
+ # Prefetch the next layer
58
+ self.prefetch_layer((layer_idx + 1) % len(self))
59
+ else:
60
+ key_tensor = self.key_cache[layer_idx]
61
+ value_tensor = self.value_cache[layer_idx]
62
+ return (key_tensor, value_tensor)
63
+ else:
64
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
65
+
66
+
67
+ def update(
68
+ self,
69
+ key_states: torch.Tensor,
70
+ value_states: torch.Tensor,
71
+ layer_idx: int,
72
+ cache_kwargs: Optional[Dict[str, Any]] = None,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
76
+ Parameters:
77
+ key_states (`torch.Tensor`):
78
+ The new key states to cache.
79
+ value_states (`torch.Tensor`):
80
+ The new value states to cache.
81
+ layer_idx (`int`):
82
+ The index of the layer to cache the states for.
83
+ cache_kwargs (`Dict[str, Any]`, `optional`):
84
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
85
+ Return:
86
+ A tuple containing the updated key and value states.
87
+ """
88
+ # Update the cache
89
+ if len(self.key_cache) < layer_idx:
90
+ raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
91
+ elif len(self.key_cache) == layer_idx:
92
+ # only cache the states for condition tokens
93
+ key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
94
+ value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
95
+
96
+ # Update the number of seen tokens
97
+ if layer_idx == 0:
98
+ self._seen_tokens += key_states.shape[-2]
99
+
100
+ self.key_cache.append(key_states)
101
+ self.value_cache.append(value_states)
102
+ self.original_device.append(key_states.device)
103
+ if self.offload_kv_cache:
104
+ self.evict_previous_layer(layer_idx)
105
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
106
+ else:
107
+ # only cache the states for condition tokens
108
+ key_tensor, value_tensor = self[layer_idx]
109
+ k = torch.cat([key_tensor, key_states], dim=-2)
110
+ v = torch.cat([value_tensor, value_states], dim=-2)
111
+ return k, v
112
+
113
+
114
+
115
+ class OmniGenScheduler:
116
+ def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
117
+ self.num_steps = num_steps
118
+ self.time_shift = time_shifting_factor
119
+
120
+ t = torch.linspace(0, 1, num_steps+1)
121
+ t = t / (t + time_shifting_factor - time_shifting_factor * t)
122
+ self.sigma = t
123
+
124
+ def crop_kv_cache(self, past_key_values, num_tokens_for_img):
125
+ # return
126
+ crop_past_key_values = ()
127
+ for layer_idx in range(len(past_key_values)):
128
+ key_states, value_states = past_key_values[layer_idx][:2]
129
+ crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
130
+ # return crop_past_key_values
131
+ return DynamicCache.from_legacy_cache(crop_past_key_values)
132
+
133
+ def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
134
+ if isinstance(position_ids, list):
135
+ for i in range(len(position_ids)):
136
+ position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
137
+ else:
138
+ position_ids = position_ids[:, -(num_tokens_for_img+1):]
139
+ return position_ids
140
+
141
+ def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
142
+ if isinstance(attention_mask, list):
143
+ return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
144
+ return attention_mask[..., -(num_tokens_for_img+1):, :]
145
+
146
+ def crop_cache(self, cache, num_tokens_for_img):
147
+ for i in range(len(cache.key_cache)):
148
+ cache.key_cache[i] = cache.key_cache[i][..., :-(num_tokens_for_img+1), :]
149
+ cache.value_cache[i] = cache.value_cache[i][..., :-(num_tokens_for_img+1), :]
150
+
151
+ return cache
152
+
153
+ def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True, offload_kv_cache: bool=True):
154
+ num_tokens_for_img = z.size(-1)*z.size(-2) // 4
155
+ if isinstance(model_kwargs['input_ids'], list):
156
+ cache = [OmniGenCache(num_tokens_for_img, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
157
+ else:
158
+ cache = OmniGenCache(num_tokens_for_img, offload_kv_cache) if use_kv_cache else None
159
+ results = {}
160
+ for i in tqdm(range(self.num_steps)):
161
+ timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
162
+ pred, cache = func(z, timesteps, past_key_values=cache, **model_kwargs)
163
+ sigma_next = self.sigma[i+1]
164
+ sigma = self.sigma[i]
165
+ z = z + (sigma_next - sigma) * pred
166
+ if i == 0 and use_kv_cache:
167
+ num_tokens_for_img = z.size(-1)*z.size(-2) // 4
168
+ if isinstance(cache, list):
169
+ model_kwargs['input_ids'] = [None] * len(cache)
170
+ else:
171
+ model_kwargs['input_ids'] = None
172
+
173
+ model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
174
+ model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
175
+
176
+ del cache
177
+ torch.cuda.empty_cache()
178
+ gc.collect()
179
+ return z
180
+
181
+
OmniGen/OmniGen/train.py ADDED
File without changes
OmniGen/OmniGen/train_helper/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .data import DatasetFromJson, TrainDataCollator
2
+ from .loss import training_losses
OmniGen/OmniGen/train_helper/data.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datasets
3
+ from datasets import load_dataset, ClassLabel, concatenate_datasets
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+ from PIL import Image
8
+ import json
9
+ import copy
10
+ # import torchvision.transforms as T
11
+ from torchvision import transforms
12
+ import pickle
13
+ import re
14
+
15
+ from OmniGen import OmniGenProcessor
16
+ from OmniGen.processor import OmniGenCollator
17
+
18
+
19
+ class DatasetFromJson(torch.utils.data.Dataset):
20
+ def __init__(
21
+ self,
22
+ json_file: str,
23
+ image_path: str,
24
+ processer: OmniGenProcessor,
25
+ image_transform,
26
+ max_input_length_limit: int = 18000,
27
+ condition_dropout_prob: float = 0.1,
28
+ keep_raw_resolution: bool = True,
29
+ ):
30
+
31
+ self.image_transform = image_transform
32
+ self.processer = processer
33
+ self.condition_dropout_prob = condition_dropout_prob
34
+ self.max_input_length_limit = max_input_length_limit
35
+ self.keep_raw_resolution = keep_raw_resolution
36
+
37
+ self.data = load_dataset('json', data_files=json_file)['train']
38
+ self.image_path = image_path
39
+
40
+ def process_image(self, image_file):
41
+ if self.image_path is not None:
42
+ image_file = os.path.join(self.image_path, image_file)
43
+ image = Image.open(image_file).convert('RGB')
44
+ return self.image_transform(image)
45
+
46
+ def get_example(self, index):
47
+ example = self.data[index]
48
+
49
+ instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
50
+ if random.random() < self.condition_dropout_prob:
51
+ instruction = '<cfg>'
52
+ input_images = None
53
+ if input_images is not None:
54
+ input_images = [self.process_image(x) for x in input_images]
55
+ mllm_input = self.processer.process_multi_modal_prompt(instruction, input_images)
56
+
57
+ output_image = self.process_image(output_image)
58
+
59
+ return (mllm_input, output_image)
60
+
61
+
62
+ def __getitem__(self, index):
63
+ return self.get_example(index)
64
+ for _ in range(8):
65
+ try:
66
+ mllm_input, output_image = self.get_example(index)
67
+ if len(mllm_input['input_ids']) > self.max_input_length_limit:
68
+ raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
69
+ return mllm_input, output_image
70
+ except Exception as e:
71
+ print("error when loading data: ", e)
72
+ print(self.data[index])
73
+ index = random.randint(0, len(self.data)-1)
74
+ raise RuntimeError("Too many bad data.")
75
+
76
+
77
+ def __len__(self):
78
+ return len(self.data)
79
+
80
+
81
+
82
+ class TrainDataCollator(OmniGenCollator):
83
+ def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
84
+ self.pad_token_id = pad_token_id
85
+ self.hidden_size = hidden_size
86
+ self.keep_raw_resolution = keep_raw_resolution
87
+
88
+ def __call__(self, features):
89
+ mllm_inputs = [f[0] for f in features]
90
+
91
+ output_images = [f[1].unsqueeze(0) for f in features]
92
+ target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]
93
+
94
+ all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
95
+
96
+ if not self.keep_raw_resolution:
97
+ output_image = torch.cat(output_image, dim=0)
98
+ if len(pixel_values) > 0:
99
+ all_pixel_values = torch.cat(all_pixel_values, dim=0)
100
+ else:
101
+ all_pixel_values = None
102
+
103
+ data = {"input_ids": all_padded_input_ids,
104
+ "attention_mask": all_attention_mask,
105
+ "position_ids": all_position_ids,
106
+ "input_pixel_values": all_pixel_values,
107
+ "input_image_sizes": all_image_sizes,
108
+ "padding_images": all_padding_images,
109
+ "output_images": output_images,
110
+ }
111
+ return data
112
+
113
+
114
+
115
+
116
+
OmniGen/OmniGen/train_helper/loss.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def sample_x0(x1):
5
+ """Sampling x0 & t based on shape of x1 (if needed)
6
+ Args:
7
+ x1 - data point; [batch, *dim]
8
+ """
9
+ if isinstance(x1, (list, tuple)):
10
+ x0 = [torch.randn_like(img_start) for img_start in x1]
11
+ else:
12
+ x0 = torch.randn_like(x1)
13
+
14
+ return x0
15
+
16
+ def sample_timestep(x1):
17
+ u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
18
+ t = 1 / (1 + torch.exp(-u))
19
+ t = t.to(x1[0])
20
+ return t
21
+
22
+
23
+ def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
24
+ """Loss for training torche score model
25
+ Args:
26
+ - model: backbone model; could be score, noise, or velocity
27
+ - x1: datapoint
28
+ - model_kwargs: additional arguments for torche model
29
+ """
30
+ if model_kwargs == None:
31
+ model_kwargs = {}
32
+
33
+ B = len(x1)
34
+
35
+ x0 = sample_x0(x1)
36
+ t = sample_timestep(x1)
37
+
38
+ if isinstance(x1, (list, tuple)):
39
+ xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
40
+ ut = [x1[i] - x0[i] for i in range(B)]
41
+ else:
42
+ dims = [1] * (len(x1.size()) - 1)
43
+ t_ = t.view(t.size(0), *dims)
44
+ xt = t_ * x1 + (1 - t_) * x0
45
+ ut = x1 - x0
46
+
47
+ model_output = model(xt, t, **model_kwargs)
48
+
49
+ terms = {}
50
+
51
+ if isinstance(x1, (list, tuple)):
52
+ assert len(model_output) == len(ut) == len(x1)
53
+ for i in range(B):
54
+ terms["loss"] = torch.stack(
55
+ [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
56
+ dim=0,
57
+ )
58
+ else:
59
+ terms["loss"] = mean_flat(((model_output - ut) ** 2))
60
+
61
+ return terms
62
+
63
+
64
+ def mean_flat(x):
65
+ """
66
+ Take torche mean over all non-batch dimensions.
67
+ """
68
+ return torch.mean(x, dim=list(range(1, len(x.size()))))
OmniGen/OmniGen/transformer.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
+ from huggingface_hub import snapshot_download
10
+
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithPast,
13
+ CausalLMOutputWithPast,
14
+ SequenceClassifierOutputWithPast,
15
+ TokenClassifierOutput,
16
+ )
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers import Phi3Config, Phi3Model
19
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Phi3Transformer(Phi3Model):
26
+ """
27
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
28
+ We only modified the attention mask
29
+ Args:
30
+ config: Phi3Config
31
+ """
32
+ def prefetch_layer(self, layer_idx: int, device: torch.device):
33
+ "Starts prefetching the next layer cache"
34
+ with torch.cuda.stream(self.prefetch_stream):
35
+ # Prefetch next layer tensors to GPU
36
+ for name, param in self.layers[layer_idx].named_parameters():
37
+ param.data = param.data.to(device, non_blocking=True)
38
+
39
+ def evict_previous_layer(self, layer_idx: int):
40
+ "Moves the previous layer cache to the CPU"
41
+ prev_layer_idx = layer_idx - 1
42
+ for name, param in self.layers[prev_layer_idx].named_parameters():
43
+ param.data = param.data.to("cpu", non_blocking=True)
44
+
45
+ def get_offlaod_layer(self, layer_idx: int, device: torch.device):
46
+ # init stream
47
+ if not hasattr(self, "prefetch_stream"):
48
+ self.prefetch_stream = torch.cuda.Stream()
49
+
50
+ # delete previous layer
51
+ torch.cuda.current_stream().synchronize()
52
+ self.evict_previous_layer(layer_idx)
53
+
54
+ # make sure the current layer is ready
55
+ torch.cuda.synchronize(self.prefetch_stream)
56
+
57
+ # load next layer
58
+ self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
59
+
60
+
61
+ def forward(
62
+ self,
63
+ input_ids: torch.LongTensor = None,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ position_ids: Optional[torch.LongTensor] = None,
66
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
67
+ inputs_embeds: Optional[torch.FloatTensor] = None,
68
+ use_cache: Optional[bool] = None,
69
+ output_attentions: Optional[bool] = None,
70
+ output_hidden_states: Optional[bool] = None,
71
+ return_dict: Optional[bool] = None,
72
+ cache_position: Optional[torch.LongTensor] = None,
73
+ offload_model: Optional[bool] = False,
74
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
75
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
76
+ output_hidden_states = (
77
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
78
+ )
79
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
80
+
81
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
82
+
83
+ if (input_ids is None) ^ (inputs_embeds is not None):
84
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
85
+
86
+ if self.gradient_checkpointing and self.training:
87
+ if use_cache:
88
+ logger.warning_once(
89
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
90
+ )
91
+ use_cache = False
92
+
93
+ # kept for BC (non `Cache` `past_key_values` inputs)
94
+ return_legacy_cache = False
95
+ if use_cache and not isinstance(past_key_values, Cache):
96
+ return_legacy_cache = True
97
+ if past_key_values is None:
98
+ past_key_values = DynamicCache()
99
+ else:
100
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
101
+ logger.warning_once(
102
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
103
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
104
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
105
+ )
106
+
107
+ # if inputs_embeds is None:
108
+ # inputs_embeds = self.embed_tokens(input_ids)
109
+
110
+ # if cache_position is None:
111
+ # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
112
+ # cache_position = torch.arange(
113
+ # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
114
+ # )
115
+ # if position_ids is None:
116
+ # position_ids = cache_position.unsqueeze(0)
117
+
118
+ if attention_mask is not None and attention_mask.dim() == 3:
119
+ dtype = inputs_embeds.dtype
120
+ min_dtype = torch.finfo(dtype).min
121
+ attention_mask = (1 - attention_mask) * min_dtype
122
+ attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
123
+ else:
124
+ raise
125
+ # causal_mask = self._update_causal_mask(
126
+ # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
127
+ # )
128
+
129
+ hidden_states = inputs_embeds
130
+
131
+ # decoder layers
132
+ all_hidden_states = () if output_hidden_states else None
133
+ all_self_attns = () if output_attentions else None
134
+ next_decoder_cache = None
135
+
136
+ layer_idx = -1
137
+ for decoder_layer in self.layers:
138
+ layer_idx += 1
139
+
140
+ if output_hidden_states:
141
+ all_hidden_states += (hidden_states,)
142
+
143
+ if self.gradient_checkpointing and self.training:
144
+ layer_outputs = self._gradient_checkpointing_func(
145
+ decoder_layer.__call__,
146
+ hidden_states,
147
+ attention_mask,
148
+ position_ids,
149
+ past_key_values,
150
+ output_attentions,
151
+ use_cache,
152
+ cache_position,
153
+ )
154
+ else:
155
+ if offload_model and not self.training:
156
+ self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
157
+ layer_outputs = decoder_layer(
158
+ hidden_states,
159
+ attention_mask=attention_mask,
160
+ position_ids=position_ids,
161
+ past_key_value=past_key_values,
162
+ output_attentions=output_attentions,
163
+ use_cache=use_cache,
164
+ cache_position=cache_position,
165
+ )
166
+
167
+ hidden_states = layer_outputs[0]
168
+
169
+ if use_cache:
170
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
171
+
172
+ if output_attentions:
173
+ all_self_attns += (layer_outputs[1],)
174
+
175
+ hidden_states = self.norm(hidden_states)
176
+
177
+ # add hidden states from the last decoder layer
178
+ if output_hidden_states:
179
+ print('************')
180
+ all_hidden_states += (hidden_states,)
181
+
182
+ next_cache = next_decoder_cache if use_cache else None
183
+ if return_legacy_cache:
184
+ next_cache = next_cache.to_legacy_cache()
185
+
186
+ if not return_dict:
187
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
188
+ return BaseModelOutputWithPast(
189
+ last_hidden_state=hidden_states,
190
+ past_key_values=next_cache,
191
+ hidden_states=all_hidden_states,
192
+ attentions=all_self_attns,
193
+ )
194
+
OmniGen/OmniGen/utils.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from PIL import Image
4
+ import torch
5
+ import numpy as np
6
+
7
+ def create_logger(logging_dir):
8
+ """
9
+ Create a logger that writes to a log file and stdout.
10
+ """
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
14
+ datefmt='%Y-%m-%d %H:%M:%S',
15
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+ return logger
19
+
20
+
21
+ @torch.no_grad()
22
+ def update_ema(ema_model, model, decay=0.9999):
23
+ """
24
+ Step the EMA model towards the current model.
25
+ """
26
+ ema_params = dict(ema_model.named_parameters())
27
+ for name, param in model.named_parameters():
28
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
29
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
30
+
31
+
32
+
33
+
34
+ def requires_grad(model, flag=True):
35
+ """
36
+ Set requires_grad flag for all parameters in a model.
37
+ """
38
+ for p in model.parameters():
39
+ p.requires_grad = flag
40
+
41
+
42
+ def center_crop_arr(pil_image, image_size):
43
+ """
44
+ Center cropping implementation from ADM.
45
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
46
+ """
47
+ while min(*pil_image.size) >= 2 * image_size:
48
+ pil_image = pil_image.resize(
49
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
50
+ )
51
+
52
+ scale = image_size / min(*pil_image.size)
53
+ pil_image = pil_image.resize(
54
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
55
+ )
56
+
57
+ arr = np.array(pil_image)
58
+ crop_y = (arr.shape[0] - image_size) // 2
59
+ crop_x = (arr.shape[1] - image_size) // 2
60
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
61
+
62
+
63
+
64
+ def crop_arr(pil_image, max_image_size):
65
+ while min(*pil_image.size) >= 2 * max_image_size:
66
+ pil_image = pil_image.resize(
67
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
68
+ )
69
+
70
+ if max(*pil_image.size) > max_image_size:
71
+ scale = max_image_size / max(*pil_image.size)
72
+ pil_image = pil_image.resize(
73
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
74
+ )
75
+
76
+ if min(*pil_image.size) < 16:
77
+ scale = 16 / min(*pil_image.size)
78
+ pil_image = pil_image.resize(
79
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
80
+ )
81
+
82
+ arr = np.array(pil_image)
83
+ crop_y1 = (arr.shape[0] % 16) // 2
84
+ crop_y2 = arr.shape[0] % 16 - crop_y1
85
+
86
+ crop_x1 = (arr.shape[1] % 16) // 2
87
+ crop_x2 = arr.shape[1] % 16 - crop_x1
88
+
89
+ arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
90
+ return Image.fromarray(arr)
91
+
92
+
93
+
94
+ def vae_encode(vae, x, weight_dtype):
95
+ if x is not None:
96
+ if vae.config.shift_factor is not None:
97
+ x = vae.encode(x).latent_dist.sample()
98
+ x = (x - vae.config.shift_factor) * vae.config.scaling_factor
99
+ else:
100
+ x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
101
+ x = x.to(weight_dtype)
102
+ return x
103
+
104
+ def vae_encode_list(vae, x, weight_dtype):
105
+ latents = []
106
+ for img in x:
107
+ img = vae_encode(vae, img, weight_dtype)
108
+ latents.append(img)
109
+ return latents
110
+
OmniGen/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: OmniGen
3
+ emoji: 🖼
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.4.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ tags:
12
+ - dwpose
13
+ - pose
14
+ - Text-to-Image
15
+ - Image-to-Image
16
+ - language models
17
+ - LLMs
18
+ short_description: Image generator/identifier/reposer
19
+ ---
20
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
OmniGen/app.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ import random
5
+ import spaces
6
+
7
+ from OmniGen import OmniGenPipeline
8
+
9
+ pipe = OmniGenPipeline.from_pretrained(
10
+ "Shitao/OmniGen-v1"
11
+ )
12
+
13
+ @spaces.GPU(duration=180)
14
+ def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model,
15
+ use_input_image_size_as_output, max_input_image_size, randomize_seed):
16
+ input_images = [img1, img2, img3]
17
+ # Delete None
18
+ input_images = [img for img in input_images if img is not None]
19
+ if len(input_images) == 0:
20
+ input_images = None
21
+
22
+ if randomize_seed:
23
+ seed = random.randint(0, 10000000)
24
+
25
+ output = pipe(
26
+ prompt=text,
27
+ input_images=input_images,
28
+ height=height,
29
+ width=width,
30
+ guidance_scale=guidance_scale,
31
+ img_guidance_scale=img_guidance_scale,
32
+ num_inference_steps=inference_steps,
33
+ separate_cfg_infer=separate_cfg_infer,
34
+ use_kv_cache=True,
35
+ offload_kv_cache=True,
36
+ offload_model=offload_model,
37
+ use_input_image_size_as_output=use_input_image_size_as_output,
38
+ seed=seed,
39
+ max_input_image_size=max_input_image_size,
40
+ )
41
+ img = output[0]
42
+ return img
43
+
44
+
45
+
46
+ def get_example():
47
+ case = [
48
+ [
49
+ "A curly-haired man in a red shirt is drinking tea.",
50
+ None,
51
+ None,
52
+ None,
53
+ 1024,
54
+ 1024,
55
+ 2.5,
56
+ 1.6,
57
+ 0,
58
+ 1024,
59
+ False,
60
+ False,
61
+ ],
62
+ [
63
+ "The woman in <img><|image_1|></img> waves her hand happily in the crowd",
64
+ "./imgs/test_cases/zhang.png",
65
+ None,
66
+ None,
67
+ 1024,
68
+ 1024,
69
+ 2.5,
70
+ 1.9,
71
+ 128,
72
+ 1024,
73
+ False,
74
+ False,
75
+ ],
76
+ [
77
+ "A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
78
+ "./imgs/test_cases/two_man.jpg",
79
+ None,
80
+ None,
81
+ 1024,
82
+ 1024,
83
+ 2.5,
84
+ 1.6,
85
+ 0,
86
+ 1024,
87
+ False,
88
+ False,
89
+ ],
90
+ [
91
+ "Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. Another woman is <img><|image_2|></img>.",
92
+ "./imgs/test_cases/mckenna.jpg",
93
+ "./imgs/test_cases/Amanda.jpg",
94
+ None,
95
+ 1024,
96
+ 1024,
97
+ 2.5,
98
+ 1.8,
99
+ 65,
100
+ 1024,
101
+ False,
102
+ False,
103
+ ],
104
+ [
105
+ "A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
106
+ "./imgs/test_cases/1.jpg",
107
+ "./imgs/test_cases/2.jpg",
108
+ None,
109
+ 1024,
110
+ 1024,
111
+ 2.5,
112
+ 1.6,
113
+ 60,
114
+ 1024,
115
+ False,
116
+ False,
117
+ ],
118
+ [
119
+ "A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
120
+ "./imgs/test_cases/3.jpg",
121
+ "./imgs/test_cases/4.jpg",
122
+ None,
123
+ 1024,
124
+ 1024,
125
+ 2.5,
126
+ 1.8,
127
+ 66,
128
+ 1024,
129
+ False,
130
+ False,
131
+ ],
132
+ [
133
+ "The flower <img><|image_1|></img> is placed in the vase which is in the middle of <img><|image_2|></img> on a wooden table of a living room",
134
+ "./imgs/test_cases/rose.jpg",
135
+ "./imgs/test_cases/vase.jpg",
136
+ None,
137
+ 1024,
138
+ 1024,
139
+ 2.5,
140
+ 1.6,
141
+ 0,
142
+ 1024,
143
+ False,
144
+ False,
145
+ ],
146
+ [
147
+ "<img><|image_1|><img>\n Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola.",
148
+ "./imgs/demo_cases/t2i_woman_with_book.png",
149
+ None,
150
+ None,
151
+ None,
152
+ None,
153
+ 2.5,
154
+ 1.6,
155
+ 222,
156
+ 1024,
157
+ False,
158
+ True,
159
+ ],
160
+ [
161
+ "Detect the skeleton of human in this image: <img><|image_1|></img>.",
162
+ "./imgs/test_cases/control.jpg",
163
+ None,
164
+ None,
165
+ 1024,
166
+ 1024,
167
+ 2.0,
168
+ 1.6,
169
+ 0,
170
+ 1024,
171
+ False,
172
+ True,
173
+ ],
174
+ [
175
+ "Generate a new photo using the following picture and text as conditions: <img><|image_1|><img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
176
+ "./imgs/demo_cases/skeletal.png",
177
+ None,
178
+ None,
179
+ 1024,
180
+ 1024,
181
+ 2,
182
+ 1.6,
183
+ 999,
184
+ 1024,
185
+ False,
186
+ True,
187
+ ],
188
+ [
189
+ "Following the pose of this image <img><|image_1|><img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
190
+ "./imgs/demo_cases/edit.png",
191
+ None,
192
+ None,
193
+ 1024,
194
+ 1024,
195
+ 2.0,
196
+ 1.6,
197
+ 123,
198
+ 1024,
199
+ False,
200
+ True,
201
+ ],
202
+ [
203
+ "Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
204
+ "./imgs/demo_cases/edit.png",
205
+ None,
206
+ None,
207
+ 1024,
208
+ 1024,
209
+ 2.0,
210
+ 1.6,
211
+ 1,
212
+ 1024,
213
+ False,
214
+ True,
215
+ ],
216
+ [
217
+ "<img><|image_1|><\/img> What item can be used to see the current time? Please highlight it in blue.",
218
+ "./imgs/test_cases/watch.jpg",
219
+ None,
220
+ None,
221
+ 1024,
222
+ 1024,
223
+ 2.5,
224
+ 1.6,
225
+ 666,
226
+ 1024,
227
+ False,
228
+ True,
229
+ ],
230
+ [
231
+ "According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
232
+ "./imgs/test_cases/icl1.jpg",
233
+ "./imgs/test_cases/icl2.jpg",
234
+ "./imgs/test_cases/icl3.jpg",
235
+ 224,
236
+ 224,
237
+ 2.5,
238
+ 1.6,
239
+ 1,
240
+ 768,
241
+ False,
242
+ False,
243
+ ],
244
+ ]
245
+ return case
246
+
247
+ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, seed, max_input_image_size, randomize_seed, use_input_image_size_as_output):
248
+ # 在函数内部设置默认值
249
+ inference_steps = 50
250
+ separate_cfg_infer = True
251
+ offload_model = False
252
+
253
+ return generate_image(
254
+ text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale,
255
+ inference_steps, seed, separate_cfg_infer, offload_model,
256
+ use_input_image_size_as_output, max_input_image_size, randomize_seed
257
+ )
258
+
259
+ description = """
260
+ OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
261
+ For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
262
+ For example, use an image of a woman to generate a new image:
263
+ prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
264
+ Tips:
265
+ - For image editing task and controlnet task, we recommend setting the height and width of output image as the same as input image. For example, if you want to edit a 512x512 image, you should set the height and width of output image as 512x512. You also can set the `use_input_image_size_as_output` to automatically set the height and width of output image as the same as input image.
266
+ - For out-of-memory or time cost, you can set `offload_model=True` or refer to [./docs/inference.md#requiremented-resources](https://github.com/VectorSpaceLab/OmniGen/blob/main/docs/inference.md#requiremented-resources) to select a appropriate setting.
267
+ - If inference time is too long when inputting multiple images, please try to reduce the `max_input_image_size`. For more details please refer to [./docs/inference.md#requiremented-resources](https://github.com/VectorSpaceLab/OmniGen/blob/main/docs/inference.md#requiremented-resources).
268
+ - Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
269
+ - Low-quality: More detailed prompts will lead to better results.
270
+ - Animate Style: If the generated images are in animate style, you can try to add `photo` to the prompt`.
271
+ - Edit generated image. If you generate an image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
272
+ - For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
273
+
274
+
275
+ **HF Spaces often encounter errors due to quota limitations, so recommend to run it locally.**
276
+ """
277
+
278
+ article = """
279
+ ---
280
+ **Citation**
281
+ <br>
282
+ If you find this repository useful, please consider giving a star ⭐ and a citation
283
+ ```
284
+ @article{xiao2024omnigen,
285
+ title={Omnigen: Unified image generation},
286
+ author={Xiao, Shitao and Wang, Yueze and Zhou, Junjie and Yuan, Huaying and Xing, Xingrun and Yan, Ruiran and Wang, Shuting and Huang, Tiejun and Liu, Zheng},
287
+ journal={arXiv preprint arXiv:2409.11340},
288
+ year={2024}
289
+ }
290
+ ```
291
+ **Contact**
292
+ <br>
293
+ If you have any questions, please feel free to open an issue or directly reach us out via email.
294
+ """
295
+
296
+
297
+ # Gradio
298
+ with gr.Blocks() as demo:
299
+ gr.Markdown("# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)")
300
+ gr.Markdown(description)
301
+ with gr.Row():
302
+ with gr.Column():
303
+ # text prompt
304
+ prompt_input = gr.Textbox(
305
+ label="Enter your prompt, use <img><|image_i|></img> to represent i-th input image", placeholder="Type your prompt here..."
306
+ )
307
+
308
+ with gr.Row(equal_height=True):
309
+ # input images
310
+ image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
311
+ image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
312
+ image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
313
+
314
+ # slider
315
+ height_input = gr.Slider(
316
+ label="Height", minimum=128, maximum=2048, value=1024, step=16
317
+ )
318
+ width_input = gr.Slider(
319
+ label="Width", minimum=128, maximum=2048, value=1024, step=16
320
+ )
321
+
322
+ guidance_scale_input = gr.Slider(
323
+ label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
324
+ )
325
+
326
+ img_guidance_scale_input = gr.Slider(
327
+ label="img_guidance_scale", minimum=1.0, maximum=2.0, value=1.6, step=0.1
328
+ )
329
+
330
+ num_inference_steps = gr.Slider(
331
+ label="Inference Steps", minimum=1, maximum=100, value=50, step=1
332
+ )
333
+
334
+ seed_input = gr.Slider(
335
+ label="Seed", minimum=0, maximum=2147483647, value=42, step=1
336
+ )
337
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
338
+
339
+ max_input_image_size = gr.Slider(
340
+ label="max_input_image_size", minimum=128, maximum=2048, value=1024, step=16
341
+ )
342
+
343
+ separate_cfg_infer = gr.Checkbox(
344
+ label="separate_cfg_infer", info="Whether to use separate inference process for different guidance. This will reduce the memory cost.", value=True,
345
+ )
346
+ offload_model = gr.Checkbox(
347
+ label="offload_model", info="Offload model to CPU, which will significantly reduce the memory cost but slow down the generation speed. You can cancel separate_cfg_infer and set offload_model=True. If both separate_cfg_infer and offload_model are True, further reduce the memory, but slowest generation", value=False,
348
+ )
349
+ use_input_image_size_as_output = gr.Checkbox(
350
+ label="use_input_image_size_as_output", info="Automatically adjust the output image size to be same as input image size. For editing and controlnet task, it can make sure the output image has the same size as input image leading to better performance", value=False,
351
+ )
352
+
353
+ # generate
354
+ generate_button = gr.Button("Generate Image")
355
+
356
+
357
+ with gr.Column():
358
+ # output image
359
+ output_image = gr.Image(label="Output Image")
360
+
361
+ # click
362
+ generate_button.click(
363
+ generate_image,
364
+ inputs=[
365
+ prompt_input,
366
+ image_input_1,
367
+ image_input_2,
368
+ image_input_3,
369
+ height_input,
370
+ width_input,
371
+ guidance_scale_input,
372
+ img_guidance_scale_input,
373
+ num_inference_steps,
374
+ seed_input,
375
+ separate_cfg_infer,
376
+ offload_model,
377
+ use_input_image_size_as_output,
378
+ max_input_image_size,
379
+ randomize_seed,
380
+ ],
381
+ outputs=output_image,
382
+ )
383
+
384
+ gr.Examples(
385
+ examples=get_example(),
386
+ fn=run_for_examples,
387
+ inputs=[
388
+ prompt_input,
389
+ image_input_1,
390
+ image_input_2,
391
+ image_input_3,
392
+ height_input,
393
+ width_input,
394
+ guidance_scale_input,
395
+ img_guidance_scale_input,
396
+ seed_input,
397
+ max_input_image_size,
398
+ randomize_seed,
399
+ use_input_image_size_as_output,
400
+ ],
401
+ outputs=output_image,
402
+ )
403
+
404
+ gr.Markdown(article)
405
+
406
+ # launch
407
+ demo.launch()
408
+
OmniGen/docs/fine-tuning.md ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning OmniGen
2
+
3
+ Fine-tuning Omnigen can better help you handle specific image generation tasks. For example, by fine-tuning on a person's images, you can generate multiple pictures of that person while maintaining task consistency.
4
+
5
+ A lot of previous work focused on designing new networks to facilitate specific tasks. For instance, ControlNet was proposed to handle image conditions, and IP-Adapter was constructed to maintain ID features. If you want to perform new tasks, you need to build new architectures and repeatedly debug them. Adding and adjusting extra network parameters is usually time-consuming and labor-intensive, which is not user-friendly and cost-efficient enough. However, with Omnigen, all of this becomes very simple.
6
+
7
+ By comparison, Omnigen can accept multi-modal conditional inputs and has been pre-trained on various tasks. You can fine-tune it on any task without designing specialized networks like ControlNet or IP-Adapter for a specific task.
8
+
9
+ **All you need to do is prepare the data and start training. You can break the limitations of previous models, allowing Omnigen to accomplish a variety of interesting tasks, even those that have never been done before.**
10
+
11
+
12
+ ## Installation
13
+
14
+ ```bash
15
+ git clone https://github.com/VectorSpaceLab/OmniGen.git
16
+ cd OmniGen
17
+ pip install -e .
18
+ ```
19
+
20
+
21
+ ## Full fine-tuning
22
+
23
+ ### Fine-tuning command
24
+
25
+ ```bash
26
+ accelerate launch \
27
+ --num_processes=1 \
28
+ --use_fsdp \
29
+ --fsdp_offload_params false \
30
+ --fsdp_sharding_strategy SHARD_GRAD_OP \
31
+ --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
32
+ --fsdp_transformer_layer_cls_to_wrap Phi3DecoderLayer \
33
+ --fsdp_state_dict_type FULL_STATE_DICT \
34
+ --fsdp_forward_prefetch false \
35
+ --fsdp_use_orig_params True \
36
+ --fsdp_cpu_ram_efficient_loading false \
37
+ --fsdp_sync_module_states True \
38
+ train.py \
39
+ --model_name_or_path Shitao/OmniGen-v1 \
40
+ --json_file ./toy_data/toy_data.jsonl \
41
+ --image_path ./toy_data/images \
42
+ --batch_size_per_device 1 \
43
+ --lr 2e-5 \
44
+ --keep_raw_resolution \
45
+ --max_image_size 1024 \
46
+ --gradient_accumulation_steps 1 \
47
+ --ckpt_every 100 \
48
+ --epochs 100 \
49
+ --log_every 1 \
50
+ --results_dir ./results/toy_finetune
51
+ ```
52
+
53
+ Some important arguments:
54
+ - `num_processes`: number of GPU to use for training
55
+ - `model_name_or_path`: path to the pretrained model
56
+ - `json_file`: path to the json file containing the training data, e.g., ./toy_data/toy_data.jsonl
57
+ - `image_path`: path to the image folder, e.g., ./toy_data/images
58
+ - `batch_size_per_device`: batch size per device
59
+ - `lr`: learning rate
60
+ - `keep_raw_resolution`: whether to keep the original resolution of the image, if not, all images will be resized to (max_image_size, max_image_size)
61
+ - `max_image_size`: max image size
62
+ - `gradient_accumulation_steps`: number of steps to accumulate gradients
63
+ - `ckpt_every`: number of steps to save checkpoint
64
+ - `epochs`: number of epochs
65
+ - `log_every`: number of steps to log
66
+ - `results_dir`: path to the results folder
67
+
68
+ The data format of json_file is as follows:
69
+ ```
70
+ {
71
+ "instruction": str,
72
+ "input_images": [str, str, ...],
73
+ "output_images": str
74
+ }
75
+ ```
76
+ You can see a toy example in `./toy_data/toy_data.jsonl`.
77
+
78
+ If an OOM(Out of Memory) issue occurs, you can try to decrease the `batch_size_per_device` or `max_image_size`. You can also try to use LoRA instead of full fine-tuning.
79
+
80
+
81
+ ### Inference
82
+
83
+ The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load saved checkpoint:
84
+ ```python
85
+ from OmniGen import OmniGenPipeline
86
+
87
+ pipe = OmniGenPipeline.from_pretrained("checkpoint_path") # e.g., ./results/toy_finetune/checkpoints/0000200
88
+ ```
89
+
90
+
91
+
92
+
93
+
94
+ ## LoRA fine-tuning
95
+ LoRA fine-tuning is a simple way to fine-tune OmniGen with less GPU memory. To use lora, you should add `--use_lora` and `--lora_rank` to the command.
96
+
97
+ ```bash
98
+ accelerate launch \
99
+ --num_processes=1 \
100
+ train.py \
101
+ --model_name_or_path Shitao/OmniGen-v1 \
102
+ --batch_size_per_device 2 \
103
+ --condition_dropout_prob 0.01 \
104
+ --lr 3e-4 \
105
+ --use_lora \
106
+ --lora_rank 8 \
107
+ --json_file ./toy_data/toy_data.jsonl \
108
+ --image_path ./toy_data/images \
109
+ --max_input_length_limit 18000 \
110
+ --keep_raw_resolution \
111
+ --max_image_size 1024 \
112
+ --gradient_accumulation_steps 1 \
113
+ --ckpt_every 100 \
114
+ --epochs 100 \
115
+ --log_every 1 \
116
+ --results_dir ./results/toy_finetune_lora
117
+ ```
118
+
119
+ ### Inference
120
+
121
+ The checkpoint can be found at `{results_dir}/checkpoints/*`. You can use the following command to load checkpoint:
122
+ ```python
123
+ from OmniGen import OmniGenPipeline
124
+
125
+ pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
126
+ pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000100
127
+ ```
128
+
129
+
130
+ ## A simple example
131
+
132
+ Here is an example for learning new concepts: "sks dog". We use five images of one dog from [dog-example](https://huggingface.co/datasets/diffusers/dog-example).
133
+
134
+ The json file is `./toy_data/toy_subject_data.jsonl`, and the images have been saved in `./toy_data/images`.
135
+
136
+ ```bash
137
+ accelerate launch \
138
+ --num_processes=1 \
139
+ train.py \
140
+ --model_name_or_path Shitao/OmniGen-v1 \
141
+ --batch_size_per_device 2 \
142
+ --condition_dropout_prob 0.01 \
143
+ --lr 1e-3 \
144
+ --use_lora \
145
+ --lora_rank 8 \
146
+ --json_file ./toy_data/toy_subject_data.jsonl \
147
+ --image_path ./toy_data/images \
148
+ --max_input_length_limit 18000 \
149
+ --keep_raw_resolution \
150
+ --max_image_size 1024 \
151
+ --gradient_accumulation_steps 1 \
152
+ --ckpt_every 100 \
153
+ --epochs 200 \
154
+ --log_every 1 \
155
+ --results_dir ./results/toy_finetune_lora
156
+ ```
157
+
158
+ After training, you can use the following command to generate images:
159
+ ```python
160
+ from OmniGen import OmniGenPipeline
161
+
162
+ pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
163
+ pipe.merge_lora("checkpoint_path") # e.g., ./results/toy_finetune_lora/checkpoints/0000200
164
+
165
+ images = pipe(
166
+ prompt="a photo of sks dog running in the snow",
167
+ height=1024,
168
+ width=1024,
169
+ guidance_scale=3
170
+ )
171
+ images[0].save("example_sks_dog_snow.png")
172
+ ```
OmniGen/docs/inference.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference with OmniGen
2
+
3
+ To handle some complex tasks, image generation models are becoming increasingly sophisticated, leading to more and more cumbersome workflows. Existing image generation models like SD and Flux require loading many additional network modules (such as ControlNet, IP-Adapter, Reference-Net) and extra preprocessing steps (e.g., face detection, pose detection, image cropping) to generate a satisfactory image. This complex workflow is not user-friendly. We believe that future image generation models should be simpler, generating various images directly through instructions, similar to how GPT works in language generation.
4
+
5
+ Therefore, we propose OmniGen, a model capable of handling various image generation tasks within a single framework. The goal of OmniGen is to complete various image generation tasks without relying on any additional components or image preprocessing steps. OmniGen supports tasks including text-to-image generation, image editing, subject-driven image generation, and classical vision tasks, among others. More capabilities can be found in our examples. We provide inference code so you can explore more unknown functionalities yourself.
6
+
7
+
8
+
9
+ ## Install
10
+ ```bash
11
+ git clone https://github.com/staoxiao/OmniGen.git
12
+ cd OmniGen
13
+ pip install -e .
14
+ ```
15
+
16
+
17
+
18
+ ## Generate Images
19
+ You can use the following code to generate images:
20
+ ```python
21
+ from OmniGen import OmniGenPipeline
22
+
23
+ pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
24
+
25
+ # Text to Image
26
+ images = pipe(
27
+ prompt="A curly-haired man in a red shirt is drinking tea.",
28
+ height=1024,
29
+ width=1024,
30
+ guidance_scale=2.5,
31
+ seed=0,
32
+ )
33
+ images[0].save("example_t2i.png") # save output PIL Image
34
+
35
+ # Multi-modal to Image
36
+ # In prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
37
+ # You can add multiple images in the input_images. Please ensure that each image has its placeholder. For example, for the list input_images [img1_path, img2_path], the prompt needs to have two placeholders: <img><|image_1|></img>, <img><|image_2|></img>.
38
+ images = pipe(
39
+ prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
40
+ input_images=["./imgs/test_cases/two_man.jpg"],
41
+ height=1024,
42
+ width=1024,
43
+ guidance_scale=2.5,
44
+ img_guidance_scale=1.6,
45
+ max_input_image_size=1024,
46
+ separate_cfg_infer=True,
47
+ use_kv_cache=True,
48
+ offload_kv_cache=True,
49
+ offload_model=False,
50
+ use_input_image_size_as_output=False,
51
+ seed=0,
52
+ )
53
+ images[0].save("example_ti2i.png") # save output PIL image
54
+ ```
55
+
56
+ Some important arguments:
57
+ - `guidance_scale`: The strength of the guidance. Based on our experience, it is usually best to set it between 2 and 3. The higher the value, the more similar the generated image will be to the prompt. If the image appears oversaturated, please reduce the scale.
58
+ - `height` and `width`: The height and width of the generated image. The default value is 1024x1024. OmniGen support any size, but these number must be divisible by 16.
59
+ - `num_inference_steps`: The number of steps to take in the diffusion process. The higher the value, the more detailed the generated image will be.
60
+ - `max_input_image_size`: the maximum size of input image, which will be used to crop the input image to the maximum size. A smaller number will result in faster generation speed and lower memory cost.
61
+ - `separate_cfg_infer`: Whether to use separate inference process for CFG guidance. If set to True, memory cost will be lower. Default is True.
62
+ - `use_kv_cache`: Whether to use key-value cache. Default is True.
63
+ - `offload_kv_cache`: offload the cached key and value to cpu, which can save memory but slow down the generation silightly. Default is True.
64
+ - `offload_model`: offload the model to cpu, which can save memory but slow down the generation. Default is False.
65
+ - `use_input_image_size_as_output`: whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task. Default is False.
66
+ - `seed`: The seed for random number generator.
67
+
68
+ **More examples please refer to [inference.ipynb](../inference.ipynb)**
69
+
70
+
71
+ #### Input data
72
+ OmniGen can accept multi-modal input data. Specifically, you should pass two arguments: `prompt` and `input_images`.
73
+ For text to image generation, you can pass a string as `prompt`, or pass a list of strings as `prompt` to generate multiple images.
74
+
75
+ For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>`.
76
+ For example, if you want to generate an image with a person holding a bouquet of flowers, you can pass the following prompt:
77
+ ```
78
+ prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."
79
+ input_images = ["./imgs/test_cases/liuyifei.png"]
80
+ ```
81
+ The placeholder `<|image_1|>` will be replaced by the image at `input_images[0]`, i.e., `./imgs/test_cases/liuyifei.png`.
82
+
83
+ If you want to generate multiple images, you can pass a list of prompts and a list of image paths. For example:
84
+ ```
85
+ prompt = ["A woman holds a bouquet of flowers and faces the camera.", "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."]
86
+ input_images = [[], ["./imgs/test_cases/liuyifei.png"]]
87
+ ```
88
+
89
+
90
+ #### Gradio Demo
91
+ We have constructed a online demo in [Huggingface](https://huggingface.co/spaces/Shitao/OmniGen).
92
+
93
+ For the local gradio demo, you can run with the following command:
94
+ ```python
95
+ python app.py
96
+ ```
97
+
98
+
99
+ ## Tips
100
+ - For out of memory or time cost, you can refer to [./docs/inference.md#requiremented-resources](https://github.com/VectorSpaceLab/OmniGen/blob/main/docs/inference.md#requiremented-resources) to select a appropriate setting.
101
+ - Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
102
+ - Not match the prompt: If the image does not match the prompt, please try to increase the `guidance_scale`.
103
+ - Low-quality: More detailed prompt will lead to better results.
104
+ - Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
105
+ - Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
106
+ - For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
107
+ - For image editing task and controlnet task, we recommend to set the height and width of output image as the same
108
+ as input image. For example, if you want to edit a 512x512 image, you should set the height and width of output image as 512x512. You also can set the `use_input_image_size_as_output` to automatically set the height and width of output image as the same as input image.
109
+
110
+
111
+ ## Requiremented Resources
112
+
113
+ We are currently experimenting with some techniques to reduce memory usage and improve speed, including `use_kv_cache, offload_kv_cache, separate_cfg_infer, offload_model`, which you can enable in the pipeline.
114
+ The default setting is`use_kv_cache=True, offload_kv_cache=True, separate_cfg_infer=True, offload_model=False`.
115
+ To reduce memory consumption while maintaining inference speed, quantization is also a method worth exploring and is left for future work.
116
+
117
+ We conducted experiments on the A800 and RTX 3090. The memory requirements and inference times are shown in the table below. You can choose the appropriate settings based on your available resources.
118
+
119
+ **Overall, the text-to-image task requires minimal memory and time costs, comparable to other latest text-to-image models. However, when using input images, the computational cost increases. Memory usage can be reduced by extending the processing time.**
120
+
121
+
122
+ - Different image size.
123
+
124
+ Different image size (`max_input_image_size` is the max size of input image, `height` and `width` are the size of output image) with the default inference settings (`use_kv_cache=True,offload_kv_cache=True,separate_cfg_infer=True`)
125
+
126
+ For A800 GPU:
127
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
128
+ |:-------------|:----------:|:-------------------:|:---------------------:|
129
+ | max_input_image_size=1024,height=1024,width=1024 | 9G, 31s | 12G, 1m6s | 13G, 1m20s |
130
+ | max_input_image_size=512,height=1024,width=1024 | 9G, 31s | 10G, 50s | 10G, 54s |
131
+ | max_input_image_size=768,height=768,width=768 | 9G, 16s | 10G, 32s | 10G, 37s |
132
+ | max_input_image_size=512,height=512,width=512 | 9G, 7s | 9G, 14s | 9G, 15s |
133
+
134
+ For RTX 3090 GPU(24G):
135
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
136
+ |:-------------|:----------:|:-------------------:|:---------------------:|
137
+ | max_input_image_size=1024,height=1024,width=1024 | 9G, 1m17s | 12G, 2m46s | 13G, 3m23s |
138
+ | max_input_image_size=512,height=1024,width=1024 | 9G, 1m18s | 10G, 2m8s | 10G, 2m18s |
139
+ | max_input_image_size=768,height=768,width=768 | 9G, 41s | 10G, 1m22s | 10G, 1m38s |
140
+ | max_input_image_size=512,height=512,width=512 | 9G, 19s | 9G, 36s | 9G, 43s |
141
+
142
+
143
+ You can set smaller `max_input_image_size` to reduce memory usage, but note that the generation quality may be lower.
144
+ And please set the `height` and `width` the same as the size of input image for image editing task.
145
+
146
+
147
+ - Different inference settings
148
+
149
+ Default image size: height=1024, width=1024, max_input_image_size=1024
150
+
151
+ For A800 GPU:
152
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
153
+ |:-------------|:----------:|:-------------------:|:---------------------:|
154
+ | use_kv_cache | 18G, 30s | 36G, 1m | 48G, 1m13s |
155
+ | use_kv_cache,offload_kv_cache | 10G, 30s | 14G, 1m10s | 17G, 1m30s |
156
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer | 9G, 31s | 12G, 1m6s | 13G, 1m20s |
157
+ | use_kv_cache,offload_kv_cache,offload_model | 4G, 55s | 7G, 1m30s | 11G, 1m48s |
158
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model | 3G, 1m23s | 5G, 2m19s | 6G, 2m30s |
159
+
160
+ For RTX 3090 GPU(24G):
161
+ | Settings | Only Text | Text + Single Image | Text + Two Images |
162
+ |:-------------|:----------:|:-------------------:|:---------------------:|
163
+ | use_kv_cache | 18G, 1m14s | OOM | OOM |
164
+ | use_kv_cache,offload_kv_cache | 10G, 1m17s | 14G, 3m11s | 17G, 4m3s |
165
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer | 9G, 1m18s | 12G, 2m46s | 13G, 3m21s |
166
+ | use_kv_cache,offload_kv_cache,offload_model | 4G,3m1s | 7G, 4m14s | 11G, 5m4s |
167
+ | use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model | 3G, 4m56s | 5G, 7m49s | 6G, 8m6s |
OmniGen/imgs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
OmniGen/imgs/demo_cases.png ADDED

Git LFS Details

  • SHA256: 0517c97c947f8226f0f39b4ca2ac61b058e52faa59ec5085668062d0162dd21e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.42 MB
OmniGen/imgs/demo_cases/AI_Pioneers.jpg ADDED

Git LFS Details

  • SHA256: 0b7f51ae11a11781027d1f9e1e8d566438f937508e16c627bfb60acca5b1d7c0
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
OmniGen/imgs/demo_cases/edit.png ADDED

Git LFS Details

  • SHA256: 2fac5461b2c06a99664ba1299fd9fcebd781a26afa5ebc07aa07cb678ebae2af
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
OmniGen/imgs/demo_cases/entity.png ADDED

Git LFS Details

  • SHA256: 7c622ebecd3210c80e8d913158ee3564168c77c576f04b56e34d2d28bfea9e06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
OmniGen/imgs/demo_cases/reasoning.png ADDED

Git LFS Details

  • SHA256: eb510edcb5628c0def3871cef2e0351acc578a1ceef445ebbd72f8b6eb92fc9d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
OmniGen/imgs/demo_cases/same_pose.png ADDED

Git LFS Details

  • SHA256: beccbeabfc408f319661d9af1063005cbc21c977ba50b910491611ca3babd876
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
OmniGen/imgs/demo_cases/skeletal.png ADDED

Git LFS Details

  • SHA256: 30c7937855228adec69da7d9bc3170c9f434a6b159feaf02d362033c1901a671
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
OmniGen/imgs/demo_cases/skeletal2img.png ADDED

Git LFS Details

  • SHA256: 86c21341018bb633f364d40afbf361b5e5690bf1e6539b99150e4aea0ed695b6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
OmniGen/imgs/demo_cases/t2i_woman_with_book.png ADDED

Git LFS Details

  • SHA256: fe258160193adeaff960a838de01d7f7294ab09899de534f2dee99043b0c747a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
OmniGen/imgs/overall.jpg ADDED

Git LFS Details

  • SHA256: ffa229632ac0bb248eee87cf823a0dc18c22c0a81a57d4c639e7fb1986d4e029
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
OmniGen/imgs/referring.png ADDED

Git LFS Details

  • SHA256: 393fab6a4d51e84555f75162430e35a64a49670d9e6c3986cd80bca318a4fb3e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.09 MB
OmniGen/imgs/test_cases/1.jpg ADDED

Git LFS Details

  • SHA256: d2dad7a81a5c609d136fbcccc2a71007c20474103d301ae5564fa63258b4a492
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
OmniGen/imgs/test_cases/2.jpg ADDED

Git LFS Details

  • SHA256: 919ec1a20515ce921d04a5a0f6dcbe5aa4288f41c04cc62a0bd59103957b45db
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
OmniGen/imgs/test_cases/3.jpg ADDED

Git LFS Details

  • SHA256: c8fef6b304efc3fc189991ec28b83bbe15c391af55b2bfd85276eb19d49194c9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
OmniGen/imgs/test_cases/4.jpg ADDED

Git LFS Details

  • SHA256: 222e844198656a13facbf0f0afe327b074641a7f20d4120418fa1302e61db538
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
OmniGen/imgs/test_cases/Amanda.jpg ADDED

Git LFS Details

  • SHA256: c20a508b8619fca4d963f574bca51c7460f274218507c97c2853fa6eaea6d0cb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
OmniGen/imgs/test_cases/cat.jpeg ADDED

Git LFS Details

  • SHA256: 076d8d520f68d1cf2f6a66366721e03ad20cf9d385839f5c26b3d2060bb6d789
  • Pointer size: 129 Bytes
  • Size of remote file: 6.11 kB
OmniGen/imgs/test_cases/control.jpg ADDED

Git LFS Details

  • SHA256: 5ca485995cb5f4b1b792e39a99e9647745291f92689eb40f1da925f19dfdc1b5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
OmniGen/imgs/test_cases/guitar1.png ADDED

Git LFS Details

  • SHA256: d1c405aaaaa2e6660157aca7d908a4374cd5ca78585f32202b323c24d4fab756
  • Pointer size: 131 Bytes
  • Size of remote file: 836 kB
OmniGen/imgs/test_cases/icl1.jpg ADDED

Git LFS Details

  • SHA256: 0e2b2086ad903c43aee1cc98902b7c53864765c6c99758acf39618ac1ad54b0e
  • Pointer size: 130 Bytes
  • Size of remote file: 76.6 kB
OmniGen/imgs/test_cases/icl2.jpg ADDED

Git LFS Details

  • SHA256: 48bafc52d6721c636e1aec9ebcd1a76c017cc926909bf03270993dd423bc49f9
  • Pointer size: 130 Bytes
  • Size of remote file: 86.3 kB
OmniGen/imgs/test_cases/icl3.jpg ADDED

Git LFS Details

  • SHA256: 077a7c69f7ca24808922e5acc7762a62182d51031f4f9e0d035ce80b09a81d5e
  • Pointer size: 130 Bytes
  • Size of remote file: 77.8 kB
OmniGen/imgs/test_cases/img1.jpg ADDED

Git LFS Details

  • SHA256: f0036a7c89f60de3366afc6fffedf52be91cf91da36b2fb9b1722e641debfeb7
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
OmniGen/imgs/test_cases/img2.jpg ADDED

Git LFS Details

  • SHA256: 999845fb0e781cc8b9db0e9407ad6d4a528dc08677caec3fcc9f6ad6afc718bc
  • Pointer size: 131 Bytes
  • Size of remote file: 175 kB