AlexGraikos commited on
Commit
2730c93
·
verified ·
1 Parent(s): 69c7767

Delete embeddings_pixcell.py

Browse files
Files changed (1) hide show
  1. embeddings_pixcell.py +0 -230
embeddings_pixcell.py DELETED
@@ -1,230 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
- from typing import List, Optional, Tuple, Union
16
-
17
- import numpy as np
18
- import torch
19
- import torch.nn.functional as F
20
- from torch import nn
21
-
22
- from diffusers.models.activations import deprecate, FP32SiLU
23
-
24
-
25
- def pixcell_get_2d_sincos_pos_embed(
26
- embed_dim,
27
- grid_size,
28
- cls_token=False,
29
- extra_tokens=0,
30
- interpolation_scale=1.0,
31
- base_size=16,
32
- device: Optional[torch.device] = None,
33
- phase=0,
34
- output_type: str = "np",
35
- ):
36
- """
37
- Creates 2D sinusoidal positional embeddings.
38
-
39
- Args:
40
- embed_dim (`int`):
41
- The embedding dimension.
42
- grid_size (`int`):
43
- The size of the grid height and width.
44
- cls_token (`bool`, defaults to `False`):
45
- Whether or not to add a classification token.
46
- extra_tokens (`int`, defaults to `0`):
47
- The number of extra tokens to add.
48
- interpolation_scale (`float`, defaults to `1.0`):
49
- The scale of the interpolation.
50
-
51
- Returns:
52
- pos_embed (`torch.Tensor`):
53
- Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
54
- embed_dim]` if using cls_token
55
- """
56
- if output_type == "np":
57
- deprecation_message = (
58
- "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
59
- " `from_numpy` is no longer required."
60
- " Pass `output_type='pt' to use the new version now."
61
- )
62
- deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
63
- raise ValueError("Not supported")
64
- if isinstance(grid_size, int):
65
- grid_size = (grid_size, grid_size)
66
-
67
- grid_h = (
68
- torch.arange(grid_size[0], device=device, dtype=torch.float32)
69
- / (grid_size[0] / base_size)
70
- / interpolation_scale
71
- )
72
- grid_w = (
73
- torch.arange(grid_size[1], device=device, dtype=torch.float32)
74
- / (grid_size[1] / base_size)
75
- / interpolation_scale
76
- )
77
- grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
78
- grid = torch.stack(grid, dim=0)
79
-
80
- grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
81
- pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
82
- if cls_token and extra_tokens > 0:
83
- pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
84
- return pos_embed
85
-
86
-
87
- def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
88
- r"""
89
- This function generates 2D sinusoidal positional embeddings from a grid.
90
-
91
- Args:
92
- embed_dim (`int`): The embedding dimension.
93
- grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
94
-
95
- Returns:
96
- `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
97
- """
98
- if output_type == "np":
99
- deprecation_message = (
100
- "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
101
- " `from_numpy` is no longer required."
102
- " Pass `output_type='pt' to use the new version now."
103
- )
104
- deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
105
- raise ValueError("Not supported")
106
- if embed_dim % 2 != 0:
107
- raise ValueError("embed_dim must be divisible by 2")
108
-
109
- # use half of dimensions to encode grid_h
110
- emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type) # (H*W, D/2)
111
- emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type) # (H*W, D/2)
112
-
113
- emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
114
- return emb
115
-
116
-
117
- def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
118
- """
119
- This function generates 1D positional embeddings from a grid.
120
-
121
- Args:
122
- embed_dim (`int`): The embedding dimension `D`
123
- pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
124
-
125
- Returns:
126
- `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
127
- """
128
- if output_type == "np":
129
- deprecation_message = (
130
- "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
131
- " `from_numpy` is no longer required."
132
- " Pass `output_type='pt' to use the new version now."
133
- )
134
- deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
135
- raise ValueError("Not supported")
136
- if embed_dim % 2 != 0:
137
- raise ValueError("embed_dim must be divisible by 2")
138
-
139
- omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
140
- omega /= embed_dim / 2.0
141
- omega = 1.0 / 10000**omega # (D/2,)
142
-
143
- pos = pos.reshape(-1) + phase # (M,)
144
- out = torch.outer(pos, omega) # (M, D/2), outer product
145
-
146
- emb_sin = torch.sin(out) # (M, D/2)
147
- emb_cos = torch.cos(out) # (M, D/2)
148
-
149
- emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
150
- return emb
151
-
152
-
153
- class PixcellUNIProjection(nn.Module):
154
- """
155
- Projects UNI embeddings. Also handles dropout for classifier-free guidance.
156
-
157
- Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
158
- """
159
-
160
- def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
161
- super().__init__()
162
- if out_features is None:
163
- out_features = hidden_size
164
- self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
165
- if act_fn == "gelu_tanh":
166
- self.act_1 = nn.GELU(approximate="tanh")
167
- elif act_fn == "silu":
168
- self.act_1 = nn.SiLU()
169
- elif act_fn == "silu_fp32":
170
- self.act_1 = FP32SiLU()
171
- else:
172
- raise ValueError(f"Unknown activation function: {act_fn}")
173
- self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
174
-
175
- self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features ** 0.5))
176
-
177
- def forward(self, caption):
178
- hidden_states = self.linear_1(caption)
179
- hidden_states = self.act_1(hidden_states)
180
- hidden_states = self.linear_2(hidden_states)
181
- return hidden_states
182
-
183
- class UNIPosEmbed(nn.Module):
184
- """
185
- Adds positional embeddings to the UNI conditions.
186
-
187
- Args:
188
- height (`int`, defaults to `224`): The height of the image.
189
- width (`int`, defaults to `224`): The width of the image.
190
- patch_size (`int`, defaults to `16`): The size of the patches.
191
- in_channels (`int`, defaults to `3`): The number of input channels.
192
- embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
193
- layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
194
- flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
195
- bias (`bool`, defaults to `True`): Whether or not to use bias.
196
- interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
197
- pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
198
- pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
199
- """
200
-
201
- def __init__(
202
- self,
203
- height=1,
204
- width=1,
205
- base_size=16,
206
- embed_dim=768,
207
- interpolation_scale=1,
208
- pos_embed_type="sincos",
209
- ):
210
- super().__init__()
211
-
212
- num_embeds = height*width
213
- grid_size = int(num_embeds ** 0.5)
214
-
215
- if pos_embed_type == "sincos":
216
- y_pos_embed = pixcell_get_2d_sincos_pos_embed(
217
- embed_dim,
218
- grid_size,
219
- base_size=base_size,
220
- interpolation_scale=interpolation_scale,
221
- output_type="pt",
222
- phase = base_size // num_embeds
223
- )
224
- self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
225
- else:
226
- raise ValueError("`pos_embed_type` not supported")
227
-
228
- def forward(self, uni_embeds):
229
- return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)
230
-