BiliSakura commited on
Commit
2d5c526
·
verified ·
1 Parent(s): 98d2417

Delete transformer

Browse files
transformer/__pycache__/transformer_mvsplit_dit.cpython-312.pyc DELETED
Binary file (21.4 kB)
 
transformer/config.json DELETED
@@ -1,20 +0,0 @@
1
- {
2
- "_class_name": "MVSplitDiTTransformer2DModel",
3
- "_diffusers_version": "0.38.0",
4
- "context_dim": 1024,
5
- "depth": 1000,
6
- "hidden_size": 1024,
7
- "in_channels": 128,
8
- "init_alpha": 0.0,
9
- "init_beta": 0.03,
10
- "mlp_hidden_dim": 3072,
11
- "norm_eps": 1e-05,
12
- "num_heads": 8,
13
- "num_kv_heads": 8,
14
- "patch_size": 1,
15
- "qkv_bias": false,
16
- "rope_base": 10000,
17
- "trainable_rms": true,
18
- "use_rope": true,
19
- "torch_dtype": "bfloat16"
20
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transformer/diffusion_pytorch_model-00001-of-00006.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5ebd66315a82685b17dcd82724bd8cb91c5d92af4cec794ab2afa94ac48c0038
3
- size 4998288504
 
 
 
 
transformer/diffusion_pytorch_model-00002-of-00006.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b19bf5b84b48ae73e88c039809a63eb60d3f3cb74a541abe0fcba71d387e3839
3
- size 4993827600
 
 
 
 
transformer/diffusion_pytorch_model-00003-of-00006.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3d3b16f617d0934d373015d9097661d37abb46c382f747eb006e1070d28bbdbb
3
- size 4991729616
 
 
 
 
transformer/diffusion_pytorch_model-00004-of-00006.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:95883e73ca3680ba3ecbe9cdb88ea1d7794fb384ccffa47a9646d2e8e4bbef76
3
- size 4991729616
 
 
 
 
transformer/diffusion_pytorch_model-00005-of-00006.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:abb4700307e188f5cfd71b4fc2c1319d22ecee354c1ad56267cc61796a2d0fbe
3
- size 4991729616
 
 
 
 
transformer/diffusion_pytorch_model-00006-of-00006.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bf5e9d723915a3db3e6f84e181c96c1237378f4f6f1aff2230ca542cdf42a5af
3
- size 2310435160
 
 
 
 
transformer/diffusion_pytorch_model.safetensors.index.json DELETED
The diff for this file is too large to render. See raw diff
 
transformer/transformer_mvsplit_dit.py DELETED
@@ -1,350 +0,0 @@
1
- from dataclasses import dataclass
2
- import math
3
- from typing import Optional, Tuple, Union
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import nn
8
- from diffusers.models.activations import SwiGLU
9
- from diffusers.models.embeddings import PatchEmbed, apply_rotary_emb
10
- from diffusers.models.normalization import RMSNorm
11
-
12
- try:
13
- from diffusers.configuration_utils import ConfigMixin, register_to_config
14
- from diffusers.models.modeling_utils import ModelMixin
15
- from diffusers.utils import BaseOutput
16
- except Exception:
17
- class BaseOutput(dict):
18
- def __post_init__(self):
19
- self.update(self.__dict__)
20
-
21
- class _Config(dict):
22
- def __getattr__(self, key):
23
- try:
24
- return self[key]
25
- except KeyError as error:
26
- raise AttributeError(key) from error
27
-
28
- class ConfigMixin:
29
- config_name = "config.json"
30
-
31
- class ModelMixin(nn.Module):
32
- pass
33
-
34
- def register_to_config(init):
35
- def wrapper(self, *args, **kwargs):
36
- import inspect
37
-
38
- signature = inspect.signature(init)
39
- bound = signature.bind(self, *args, **kwargs)
40
- bound.apply_defaults()
41
- self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"})
42
- init(self, *args, **kwargs)
43
-
44
- return wrapper
45
-
46
-
47
- @dataclass
48
- class MVSplitDiTTransformer2DModelOutput(BaseOutput):
49
- sample: torch.FloatTensor
50
-
51
-
52
- class TwoDimRotary(nn.Module):
53
- def __init__(self, dim: int, base: int = 10000):
54
- super().__init__()
55
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, dtype=torch.float32) / max(dim, 1)))
56
- self.register_buffer("inv_freq", inv_freq, persistent=False)
57
-
58
- def forward(
59
- self,
60
- height: int,
61
- width: int,
62
- device: torch.device,
63
- dtype: torch.dtype,
64
- ) -> Tuple[torch.Tensor, torch.Tensor]:
65
- pos_h = torch.arange(height, device=device, dtype=self.inv_freq.dtype)
66
- pos_w = torch.arange(width, device=device, dtype=self.inv_freq.dtype)
67
- freqs_h = torch.outer(pos_h, self.inv_freq).unsqueeze(1).repeat(1, width, 1)
68
- freqs_w = torch.outer(pos_w, self.inv_freq).unsqueeze(0).repeat(height, 1, 1)
69
- freqs = torch.cat([freqs_h, freqs_w], dim=-1).reshape(height * width, -1)
70
- cos = freqs.cos().to(dtype=dtype)
71
- sin = freqs.sin().to(dtype=dtype)
72
- return cos, sin
73
-
74
-
75
- class QKNorm(nn.Module):
76
- def __init__(self, dim: int, eps: float = 1e-6, trainable: bool = False):
77
- super().__init__()
78
- self.query_norm = RMSNorm(dim, eps=eps, elementwise_affine=trainable)
79
- self.key_norm = RMSNorm(dim, eps=eps, elementwise_affine=trainable)
80
-
81
- def forward(self, query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
82
- return self.query_norm(query), self.key_norm(key)
83
-
84
-
85
- class FusedMVSplitNorm1(nn.Module):
86
- def __init__(self, dim: int, eps: float = 1e-5, init_alpha: float = 0.0, init_beta: float = 0.03):
87
- super().__init__()
88
- self.eps = eps
89
- self.alpha = nn.Parameter(torch.full((dim,), init_alpha))
90
- self.beta = nn.Parameter(torch.full((dim,), init_beta))
91
- self.weight = nn.Parameter(torch.ones(dim))
92
-
93
- def _rms_norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
94
- original_dtype = hidden_states.dtype
95
- hidden_states = hidden_states.float()
96
- hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(dim=-1, keepdim=True) + self.eps)
97
- hidden_states = hidden_states * self.weight.float()
98
- return hidden_states.to(dtype=original_dtype)
99
-
100
- def forward(
101
- self,
102
- residual: torch.Tensor,
103
- update: torch.Tensor,
104
- l_image_tokens: Optional[int] = None,
105
- ) -> torch.Tensor:
106
- if l_image_tokens is not None and 0 < l_image_tokens < residual.shape[1]:
107
- residual_img, residual_txt = residual[:, :l_image_tokens], residual[:, l_image_tokens:]
108
- update_img, update_txt = update[:, :l_image_tokens], update[:, l_image_tokens:]
109
-
110
- residual_img_mean = residual_img.mean(dim=1, keepdim=True)
111
- residual_txt_mean = residual_txt.mean(dim=1, keepdim=True)
112
- update_img_mean = update_img.mean(dim=1, keepdim=True)
113
- update_txt_mean = update_txt.mean(dim=1, keepdim=True)
114
-
115
- update_img_var = update_img - update_img_mean
116
- update_txt_var = update_txt - update_txt_mean
117
-
118
- alpha = self.alpha.view(1, 1, -1)
119
- beta = self.beta.view(1, 1, -1)
120
- var_update = torch.cat([update_img_var * beta, update_txt_var * beta], dim=1)
121
- mean_update = torch.cat(
122
- [
123
- (alpha * (update_img_mean - residual_img_mean)).expand_as(residual_img),
124
- (alpha * (update_txt_mean - residual_txt_mean)).expand_as(residual_txt),
125
- ],
126
- dim=1,
127
- )
128
- else:
129
- residual_mean = residual.mean(dim=1, keepdim=True)
130
- update_mean = update.mean(dim=1, keepdim=True)
131
- var_update = self.beta * (update - update_mean)
132
- mean_update = self.alpha * (update_mean - residual_mean).expand_as(residual)
133
-
134
- return self._rms_norm(residual + var_update + mean_update)
135
-
136
-
137
- class Attention(nn.Module):
138
- def __init__(
139
- self,
140
- dim: int,
141
- num_heads: int,
142
- num_kv_heads: int,
143
- qkv_bias: bool,
144
- trainable_rms: bool,
145
- ):
146
- super().__init__()
147
- if dim % num_heads != 0:
148
- raise ValueError("dim must be divisible by num_heads.")
149
-
150
- self.num_heads = num_heads
151
- self.num_kv_heads = num_kv_heads
152
- self.head_dim = dim // num_heads
153
- if self.num_heads % self.num_kv_heads != 0:
154
- raise ValueError("num_heads must be divisible by num_kv_heads.")
155
- self.num_groups = self.num_heads // self.num_kv_heads
156
- kv_dim = self.num_kv_heads * self.head_dim
157
-
158
- self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
159
- self.k_proj = nn.Linear(dim, kv_dim, bias=qkv_bias)
160
- self.v_proj = nn.Linear(dim, kv_dim, bias=qkv_bias)
161
- self.proj = nn.Linear(dim, dim, bias=False)
162
- self.qk_norm = QKNorm(self.head_dim, trainable=trainable_rms)
163
- self.scale = 1.0 / math.sqrt(self.head_dim)
164
-
165
- def forward(self, hidden_states: torch.Tensor, rope: Optional[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
166
- batch_size, _, _ = hidden_states.shape
167
- query = self.q_proj(hidden_states).reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
168
- key = self.k_proj(hidden_states).reshape(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
169
- value = self.v_proj(hidden_states).reshape(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
170
-
171
- if rope is not None:
172
- query = apply_rotary_emb(query, rope)
173
- key = apply_rotary_emb(key, rope)
174
- query, key = self.qk_norm(query, key)
175
-
176
- if self.num_groups > 1:
177
- key = torch.repeat_interleave(key, self.num_groups, dim=1)
178
- value = torch.repeat_interleave(value, self.num_groups, dim=1)
179
-
180
- hidden_states = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
181
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
182
- return self.proj(hidden_states)
183
-
184
-
185
- class DiTBlock(nn.Module):
186
- def __init__(
187
- self,
188
- hidden_size: int,
189
- num_heads: int,
190
- num_kv_heads: int,
191
- mlp_hidden_dim: int,
192
- qkv_bias: bool,
193
- trainable_rms: bool,
194
- norm_eps: float,
195
- init_alpha: float,
196
- init_beta: float,
197
- ):
198
- super().__init__()
199
- self.attn = Attention(hidden_size, num_heads, num_kv_heads, qkv_bias=qkv_bias, trainable_rms=trainable_rms)
200
- self.ffn = nn.Sequential(
201
- SwiGLU(hidden_size, mlp_hidden_dim, bias=qkv_bias),
202
- nn.Linear(mlp_hidden_dim, hidden_size, bias=qkv_bias),
203
- )
204
- self.norm1 = FusedMVSplitNorm1(hidden_size, eps=norm_eps, init_alpha=init_alpha, init_beta=init_beta)
205
- self.norm2 = FusedMVSplitNorm1(hidden_size, eps=norm_eps, init_alpha=init_alpha, init_beta=init_beta)
206
-
207
- def forward(
208
- self,
209
- hidden_states: torch.Tensor,
210
- rope: Optional[Tuple[torch.Tensor, torch.Tensor]],
211
- l_image_tokens: Optional[int],
212
- ) -> torch.Tensor:
213
- residual = hidden_states
214
- hidden_states = self.attn(hidden_states, rope=rope)
215
- hidden_states = self.norm1(residual, hidden_states, l_image_tokens=l_image_tokens)
216
-
217
- residual = hidden_states
218
- hidden_states = self.ffn(hidden_states)
219
- hidden_states = self.norm2(residual, hidden_states, l_image_tokens=l_image_tokens)
220
- return hidden_states
221
-
222
-
223
- class MVSplitDiTTransformer2DModel(ModelMixin, ConfigMixin):
224
- config_name = "config.json"
225
-
226
- @register_to_config
227
- def __init__(
228
- self,
229
- in_channels: int = 128,
230
- patch_size: int = 1,
231
- hidden_size: int = 1024,
232
- depth: int = 1000,
233
- num_heads: int = 8,
234
- num_kv_heads: int = 8,
235
- mlp_hidden_dim: int = 3072,
236
- context_dim: int = 1024,
237
- qkv_bias: bool = False,
238
- trainable_rms: bool = False,
239
- use_rope: bool = True,
240
- rope_base: int = 10000,
241
- norm_eps: float = 1e-5,
242
- init_alpha: float = 0.0,
243
- init_beta: float = 0.03,
244
- ):
245
- super().__init__()
246
- self.in_channels = in_channels
247
- self.out_channels = in_channels
248
- self.patch_size = patch_size
249
- self.hidden_size = hidden_size
250
- self.use_rope = use_rope
251
- self.rope_dim = hidden_size // (2 * num_heads)
252
-
253
- self.patch_embed = PatchEmbed(
254
- height=1,
255
- width=1,
256
- patch_size=patch_size,
257
- in_channels=in_channels,
258
- embed_dim=hidden_size,
259
- layer_norm=False,
260
- flatten=True,
261
- bias=True,
262
- pos_embed_type=None,
263
- )
264
- self.norm_img_input = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=trainable_rms)
265
- self.norm_text_input = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=trainable_rms)
266
- self.context_proj = nn.Identity() if context_dim == hidden_size else nn.Linear(context_dim, hidden_size, bias=False)
267
- self.rope = TwoDimRotary(self.rope_dim, base=rope_base) if use_rope else None
268
-
269
- self.blocks = nn.ModuleList(
270
- [
271
- DiTBlock(
272
- hidden_size=hidden_size,
273
- num_heads=num_heads,
274
- num_kv_heads=num_kv_heads,
275
- mlp_hidden_dim=mlp_hidden_dim,
276
- qkv_bias=qkv_bias,
277
- trainable_rms=trainable_rms,
278
- norm_eps=norm_eps,
279
- init_alpha=init_alpha,
280
- init_beta=init_beta,
281
- )
282
- for _ in range(depth)
283
- ]
284
- )
285
- self.final_proj = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
286
-
287
- def _unpatchify(
288
- self,
289
- hidden_states: torch.Tensor,
290
- batch_size: int,
291
- height_tokens: int,
292
- width_tokens: int,
293
- ) -> torch.Tensor:
294
- patch = self.patch_size
295
- hidden_states = hidden_states.reshape(
296
- batch_size, height_tokens, width_tokens, patch, patch, self.out_channels
297
- )
298
- hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4).reshape(
299
- batch_size, self.out_channels, height_tokens * patch, width_tokens * patch
300
- )
301
- return hidden_states
302
-
303
- def forward(
304
- self,
305
- hidden_states: torch.Tensor,
306
- encoder_hidden_states: torch.Tensor,
307
- timestep: Optional[Union[torch.Tensor, float]] = None,
308
- return_dict: bool = True,
309
- ) -> Union[MVSplitDiTTransformer2DModelOutput, Tuple[torch.Tensor]]:
310
- del timestep
311
- if hidden_states.ndim != 4:
312
- raise ValueError("hidden_states must have shape [B, C, H, W].")
313
- if encoder_hidden_states.ndim != 3:
314
- raise ValueError("encoder_hidden_states must have shape [B, L_text, context_dim].")
315
-
316
- batch_size, channels, height, width = hidden_states.shape
317
- if channels != self.in_channels:
318
- raise ValueError(f"Expected {self.in_channels} latent channels, got {channels}.")
319
- if height % self.patch_size != 0 or width % self.patch_size != 0:
320
- raise ValueError("Latent height and width must be divisible by patch_size.")
321
-
322
- height_tokens = height // self.patch_size
323
- width_tokens = width // self.patch_size
324
- image_tokens = self.norm_img_input(self.patch_embed(hidden_states))
325
- l_image_tokens = image_tokens.shape[1]
326
-
327
- text_tokens = self.norm_text_input(self.context_proj(encoder_hidden_states))
328
- sequence = torch.cat([image_tokens, text_tokens], dim=1)
329
-
330
- rope = None
331
- if self.use_rope and self.rope is not None:
332
- cos_image, sin_image = self.rope(height_tokens, width_tokens, sequence.device, sequence.dtype)
333
- text_length = text_tokens.shape[1]
334
- rope_width = cos_image.shape[-1]
335
- if text_length > 0:
336
- cos_text = torch.ones((text_length, rope_width), device=sequence.device, dtype=sequence.dtype)
337
- sin_text = torch.zeros((text_length, rope_width), device=sequence.device, dtype=sequence.dtype)
338
- rope = (torch.cat([cos_image, cos_text], dim=0), torch.cat([sin_image, sin_text], dim=0))
339
- else:
340
- rope = (cos_image, sin_image)
341
-
342
- for block in self.blocks:
343
- sequence = block(sequence, rope=rope, l_image_tokens=l_image_tokens)
344
-
345
- sequence = self.final_proj(sequence[:, :l_image_tokens, :])
346
- sequence = self._unpatchify(sequence, batch_size=batch_size, height_tokens=height_tokens, width_tokens=width_tokens)
347
-
348
- if not return_dict:
349
- return (sequence,)
350
- return MVSplitDiTTransformer2DModelOutput(sample=sequence)