hpoghos committed
Commit 3760daa · 1 Parent(s): d67a615
requirements.txt CHANGED
@@ -28,7 +28,7 @@ torchvision==0.15.1
 modelscope==1.13.3
 tqdm==4.65.0
 xformers==0.0.19
-open-clip-torch==2.24.0
+# open-clip-torch==2.24.0
 jsonargparse[signatures]==4.27.7
 fairscale==0.4.13
 rotary-embedding-torch==0.5.3
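With open-clip-torch no longer pinned in requirements.txt, any code path that still wants the OpenCLIP encoder has to treat the package as an optional dependency. A minimal sketch of such a guard, assuming the package is installed manually when needed (illustration only, not part of this commit):

```python
# Optional-dependency guard for open_clip (provided by the open-clip-torch package,
# which this commit removes from requirements.txt).
try:
    import open_clip  # e.g. installed manually via: pip install open-clip-torch==2.24.0
    HAS_OPEN_CLIP = True
except ImportError:
    open_clip = None
    HAS_OPEN_CLIP = False
```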
t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Mapping
 import torch
 import torch.nn as nn
 import kornia
-import open_clip
+# import open_clip
 from transformers import AutoImageProcessor, AutoModel
 from transformers.models.bit.image_processing_bit import BitImageProcessor
 from einops import rearrange, repeat
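The commented-out import leaves the transformers route (AutoImageProcessor / AutoModel / BitImageProcessor) as the remaining way this file can obtain frozen image features. A rough sketch of that route; the checkpoint name facebook/dinov2-large is an illustrative assumption, not taken from this commit:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

# Hypothetical checkpoint; its processor is a BitImageProcessor, matching the import kept above.
ckpt = "facebook/dinov2-large"
processor = AutoImageProcessor.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt).eval()

image = Image.new("RGB", (256, 256))  # placeholder input image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    feats = model(**inputs).last_hidden_state  # (1, num_patches + 1, hidden_dim)
```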
@@ -52,160 +52,160 @@ class AbstractEncoder(nn.Module):
 
 
 
-class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
-    """
-    Uses the OpenCLIP vision transformer encoder for images
-    """
-
-    def __init__(
-        self,
-        arch="ViT-H-14",
-        version="laion2b_s32b_b79k",
-        device="cuda",
-        max_length=77,
-        freeze=True,
-        antialias=True,
-        ucg_rate=0.0,
-        unsqueeze_dim=False,
-        repeat_to_max_len=False,
-        num_image_crops=0,
-        output_tokens=False,
-    ):
-        super().__init__()
-        model, _, _ = open_clip.create_model_and_transforms(
-            arch,
-            device=torch.device("cpu"),
-            pretrained=version,
-        )
-        del model.transformer
-        self.model = model
-        self.max_crops = num_image_crops
-        self.pad_to_max_len = self.max_crops > 0
-        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
-        self.device = device
-        self.max_length = max_length
-        if freeze:
-            self.freeze()
-
-        self.antialias = antialias
-
-        self.register_buffer(
-            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
-        )
-        self.register_buffer(
-            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
-        )
-        self.ucg_rate = ucg_rate
-        self.unsqueeze_dim = unsqueeze_dim
-        self.stored_batch = None
-        self.model.visual.output_tokens = output_tokens
-        self.output_tokens = output_tokens
-
-    def preprocess(self, x):
-        # normalize to [0,1]
-        x = kornia.geometry.resize(
-            x,
-            (224, 224),
-            interpolation="bicubic",
-            align_corners=True,
-            antialias=self.antialias,
-        )
-        x = (x + 1.0) / 2.0
-        # renormalize according to clip
-        x = kornia.enhance.normalize(x, self.mean, self.std)
-        return x
-
-    def freeze(self):
-        self.model = self.model.eval()
-        for param in self.parameters():
-            param.requires_grad = False
-
-    def forward(self, image, no_dropout=False):
-        z = self.encode_with_vision_transformer(image)
-        tokens = None
-        if self.output_tokens:
-            z, tokens = z[0], z[1]
-        z = z.to(image.dtype)
-        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
-            z = (
-                torch.bernoulli(
-                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
-                )[:, None]
-                * z
-            )
-            if tokens is not None:
-                tokens = (
-                    expand_dims_like(
-                        torch.bernoulli(
-                            (1.0 - self.ucg_rate)
-                            * torch.ones(tokens.shape[0], device=tokens.device)
-                        ),
-                        tokens,
-                    )
-                    * tokens
-                )
-        if self.unsqueeze_dim:
-            z = z[:, None, :]
-        if self.output_tokens:
-            assert not self.repeat_to_max_len
-            assert not self.pad_to_max_len
-            return tokens, z
-        if self.repeat_to_max_len:
-            if z.dim() == 2:
-                z_ = z[:, None, :]
-            else:
-                z_ = z
-            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
-        elif self.pad_to_max_len:
-            assert z.dim() == 3
-            z_pad = torch.cat(
-                (
-                    z,
-                    torch.zeros(
-                        z.shape[0],
-                        self.max_length - z.shape[1],
-                        z.shape[2],
-                        device=z.device,
-                    ),
-                ),
-                1,
-            )
-            return z_pad, z_pad[:, 0, ...]
-        return z
-
-    def encode_with_vision_transformer(self, img):
-        # if self.max_crops > 0:
-        #     img = self.preprocess_by_cropping(img)
-        if img.dim() == 5:
-            assert self.max_crops == img.shape[1]
-            img = rearrange(img, "b n c h w -> (b n) c h w")
-        img = self.preprocess(img)
-        if not self.output_tokens:
-            assert not self.model.visual.output_tokens
-            x = self.model.visual(img)
-            tokens = None
-        else:
-            assert self.model.visual.output_tokens
-            x, tokens = self.model.visual(img)
-        if self.max_crops > 0:
-            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
-            # drop out between 0 and all along the sequence axis
-            x = (
-                torch.bernoulli(
-                    (1.0 - self.ucg_rate)
-                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
-                )
-                * x
-            )
-            if tokens is not None:
-                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
-                print(
-                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
-                    f"Check what you are doing, and then remove this message."
-                )
-        if self.output_tokens:
-            return x, tokens
-        return x
-
-    def encode(self, text):
-        return self(text)
+# class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+#     """
+#     Uses the OpenCLIP vision transformer encoder for images
+#     """
+
+#     def __init__(
+#         self,
+#         arch="ViT-H-14",
+#         version="laion2b_s32b_b79k",
+#         device="cuda",
+#         max_length=77,
+#         freeze=True,
+#         antialias=True,
+#         ucg_rate=0.0,
+#         unsqueeze_dim=False,
+#         repeat_to_max_len=False,
+#         num_image_crops=0,
+#         output_tokens=False,
+#     ):
+#         super().__init__()
+#         model, _, _ = open_clip.create_model_and_transforms(
+#             arch,
+#             device=torch.device("cpu"),
+#             pretrained=version,
+#         )
+#         del model.transformer
+#         self.model = model
+#         self.max_crops = num_image_crops
+#         self.pad_to_max_len = self.max_crops > 0
+#         self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+#         self.device = device
+#         self.max_length = max_length
+#         if freeze:
+#             self.freeze()
+
+#         self.antialias = antialias
+
+#         self.register_buffer(
+#             "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+#         )
+#         self.register_buffer(
+#             "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+#         )
+#         self.ucg_rate = ucg_rate
+#         self.unsqueeze_dim = unsqueeze_dim
+#         self.stored_batch = None
+#         self.model.visual.output_tokens = output_tokens
+#         self.output_tokens = output_tokens
+
+#     def preprocess(self, x):
+#         # normalize to [0,1]
+#         x = kornia.geometry.resize(
+#             x,
+#             (224, 224),
+#             interpolation="bicubic",
+#             align_corners=True,
+#             antialias=self.antialias,
+#         )
+#         x = (x + 1.0) / 2.0
+#         # renormalize according to clip
+#         x = kornia.enhance.normalize(x, self.mean, self.std)
+#         return x
+
+#     def freeze(self):
+#         self.model = self.model.eval()
+#         for param in self.parameters():
+#             param.requires_grad = False
+
+#     def forward(self, image, no_dropout=False):
+#         z = self.encode_with_vision_transformer(image)
+#         tokens = None
+#         if self.output_tokens:
+#             z, tokens = z[0], z[1]
+#         z = z.to(image.dtype)
+#         if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+#             z = (
+#                 torch.bernoulli(
+#                     (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+#                 )[:, None]
+#                 * z
+#             )
+#             if tokens is not None:
+#                 tokens = (
+#                     expand_dims_like(
+#                         torch.bernoulli(
+#                             (1.0 - self.ucg_rate)
+#                             * torch.ones(tokens.shape[0], device=tokens.device)
+#                         ),
+#                         tokens,
+#                     )
+#                     * tokens
+#                 )
+#         if self.unsqueeze_dim:
+#             z = z[:, None, :]
+#         if self.output_tokens:
+#             assert not self.repeat_to_max_len
+#             assert not self.pad_to_max_len
+#             return tokens, z
+#         if self.repeat_to_max_len:
+#             if z.dim() == 2:
+#                 z_ = z[:, None, :]
+#             else:
+#                 z_ = z
+#             return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+#         elif self.pad_to_max_len:
+#             assert z.dim() == 3
+#             z_pad = torch.cat(
+#                 (
+#                     z,
+#                     torch.zeros(
+#                         z.shape[0],
+#                         self.max_length - z.shape[1],
+#                         z.shape[2],
+#                         device=z.device,
+#                     ),
+#                 ),
+#                 1,
+#             )
+#             return z_pad, z_pad[:, 0, ...]
+#         return z
+
+#     def encode_with_vision_transformer(self, img):
+#         # if self.max_crops > 0:
+#         #     img = self.preprocess_by_cropping(img)
+#         if img.dim() == 5:
+#             assert self.max_crops == img.shape[1]
+#             img = rearrange(img, "b n c h w -> (b n) c h w")
+#         img = self.preprocess(img)
+#         if not self.output_tokens:
+#             assert not self.model.visual.output_tokens
+#             x = self.model.visual(img)
+#             tokens = None
+#         else:
+#             assert self.model.visual.output_tokens
+#             x, tokens = self.model.visual(img)
+#         if self.max_crops > 0:
+#             x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+#             # drop out between 0 and all along the sequence axis
+#             x = (
+#                 torch.bernoulli(
+#                     (1.0 - self.ucg_rate)
+#                     * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+#                 )
+#                 * x
+#             )
+#             if tokens is not None:
+#                 tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+#                 print(
+#                     f"You are running very experimental token-concat in {self.__class__.__name__}. "
+#                     f"Check what you are doing, and then remove this message."
+#                 )
+#         if self.output_tokens:
+#             return x, tokens
+#         return x
+
+#     def encode(self, text):
+#         return self(text)
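For context, the class commented out above is the OpenCLIP-based image embedder; before this commit it could be used roughly as follows. This is a hypothetical usage sketch based on the removed code, and it only runs with open-clip-torch installed and the class restored:

```python
import torch
# Before this commit the class could be imported from this file, e.g.:
# from t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder import (
#     FrozenOpenCLIPImageEmbedder,
# )

embedder = FrozenOpenCLIPImageEmbedder(
    arch="ViT-H-14",
    version="laion2b_s32b_b79k",
    freeze=True,  # parameters frozen, model kept in eval mode
)
# preprocess() resizes to 224x224 and maps inputs from [-1, 1] before CLIP normalization,
# so images are expected in [-1, 1].
images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0
with torch.no_grad():
    z = embedder(images)  # pooled CLIP image embeddings, (2, 1024) for ViT-H-14
```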