Ruian7P committed
Commit edc183f · verified · 1 Parent(s): 70c7632

upload head_t5_large_2e-5_ech5

head_t5_large_2e-5_ech5/config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "architectures": [
+     "Emuru"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_imuru.ImuruConfig",
+     "AutoModel": "modeling_imuru.Imuru"
+   },
+   "model_type": "emuru",
+   "slices_per_query": 1,
+   "style_enc": "full",
+   "t5_name_or_path": "google-t5/t5-large",
+   "tokenizer_name_or_path": "google/byt5-small",
+   "torch_dtype": "float32",
+   "transformers_version": "4.54.0",
+   "vae_channels": 1,
+   "vae_name_or_path": "blowing-up-groundhogs/emuru_vae"
+ }
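
The "auto_map" block above registers the bundled Python files with transformers' custom-code loader, so this checkpoint loads through AutoConfig/AutoModel only with trust_remote_code=True. A minimal sketch, assuming "head_t5_large_2e-5_ech5" resolves to this folder (a local path or the corresponding Hub repo id):

import torch
from transformers import AutoConfig, AutoModel

# auto_map routes these calls to configuration_imuru.ImuruConfig and modeling_imuru.Imuru
config = AutoConfig.from_pretrained("head_t5_large_2e-5_ech5", trust_remote_code=True)
model = AutoModel.from_pretrained("head_t5_large_2e-5_ech5", trust_remote_code=True)
print(config.style_enc)  # "full" for this checkpoint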
head_t5_large_2e-5_ech5/configuration_imuru.py ADDED
@@ -0,0 +1,20 @@
+ from transformers import PretrainedConfig
+
+ class ImuruConfig(PretrainedConfig):
+     model_type = "emuru"
+
+     def __init__(self,
+                  t5_name_or_path='google-t5/t5-large',
+                  vae_name_or_path='blowing-up-groundhogs/emuru_vae',
+                  tokenizer_name_or_path='google/byt5-small',
+                  slices_per_query=1,
+                  vae_channels=1,
+                  style_enc="mean",
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.t5_name_or_path = t5_name_or_path
+         self.vae_name_or_path = vae_name_or_path
+         self.tokenizer_name_or_path = tokenizer_name_or_path
+         self.slices_per_query = slices_per_query
+         self.vae_channels = vae_channels
+         self.style_enc = style_enc
head_t5_large_2e-5_ech5/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:598a7712b9eb1f008fef799f7e26b8f5baab4a56e870641b087657968e45a66e
+ size 2876698944
head_t5_large_2e-5_ech5/modeling_imuru.py ADDED
@@ -0,0 +1,365 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
+ from .configuration_imuru import ImuruConfig
+ from diffusers import AutoencoderKL
+ from einops.layers.torch import Rearrange
+ from einops import repeat
+ from torchvision.transforms import functional as F
+ from typing import Optional, Tuple, List, Any
+ from PIL import Image
+
+ class Imuru(PreTrainedModel):
+     """
+     Imuru is a conditional generative model that integrates a T5-based decoder with a VAE
+     for image generation conditioned on text and style images.
+     Attributes:
+         config_class (Type): Configuration class for the model.
+         tokenizer (AutoTokenizer): Tokenizer loaded from the provided tokenizer configuration.
+         T5 (T5ForConditionalGeneration): T5 model adapted for conditional generation.
+         sos (nn.Embedding): Start-of-sequence embedding.
+         vae_to_t5 (nn.Linear): Linear projection from VAE latent space to T5 hidden space.
+         t5_to_vae (nn.Linear): Linear projection from T5 hidden space back to VAE latent space.
+         padding_token (nn.Parameter): Non-trainable parameter for padding tokens.
+         padding_token_threshold (nn.Parameter): Non-trainable parameter for padding token threshold.
+         style_encoder (nn.Module): Learned scorer over style slices (only for 'MLP'/'MLP2' style encoders).
+         vae (AutoencoderKL): Pre-trained Variational Autoencoder.
+         query_rearrange (Rearrange): Layer to rearrange VAE latent representations for queries.
+         z_rearrange (Rearrange): Layer to rearrange T5 outputs back to VAE latent dimensions.
+         mse_criterion (nn.MSELoss): Mean squared error loss function.
+     """
+     config_class = ImuruConfig
+
+     def __init__(self, config: ImuruConfig) -> None:
+         """
+         Initialize the Imuru model.
+         Args:
+             config (ImuruConfig): Configuration object containing model hyperparameters and paths.
+         """
+         super().__init__(config)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
+
+         t5_config = T5Config.from_pretrained(config.t5_name_or_path)
+         t5_config.vocab_size = len(self.tokenizer)
+         self.T5 = T5ForConditionalGeneration(t5_config)
+         self.T5.lm_head = nn.Identity()
+         self.sos = nn.Embedding(1, t5_config.d_model)
+
+         vae_latent_size = 8 * config.vae_channels * config.slices_per_query
+         self.vae_to_t5 = nn.Linear(vae_latent_size, t5_config.d_model)
+         self.t5_to_vae = nn.Linear(t5_config.d_model, vae_latent_size, bias=False)
+
+         self.padding_token = nn.Parameter(torch.tensor([[-0.4951, 0.8021, 0.3429, 0.5622, 0.5271, 0.5756, 0.7194, 0.6150]]), requires_grad=False)
+         self.padding_token_threshold = nn.Parameter(torch.tensor(0.484982096850872), requires_grad=False)
+
+         self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=config.slices_per_query)
+         self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query)
+
+         self.style_enc = getattr(config, 'style_enc', 'mean')
+         print(f"Using style encoder: {self.style_enc}")
+         if self.style_enc == "MLP":  # w -> 1
+             self.style_encoder = nn.Linear(vae_latent_size, 1)
+         elif self.style_enc == "MLP2":
+             self.style_encoder = nn.Sequential(
+                 nn.Linear(vae_latent_size, vae_latent_size),
+                 nn.SiLU(),
+                 nn.Linear(vae_latent_size, 1)
+             )
+
+         self.mse_criterion = nn.MSELoss()
+         self.init_weights()
+
+         self.vae = AutoencoderKL.from_pretrained(config.vae_name_or_path)
+         self.set_training(self.vae, False)
+
+     def set_training(self, model: nn.Module, training: bool) -> None:
+         """
+         Set the training mode for a given model and freeze/unfreeze parameters accordingly.
+         Args:
+             model (nn.Module): The model to set the training mode for.
+             training (bool): If True, set the model to training mode; otherwise, evaluation mode.
+         """
+         model.train() if training else model.eval()
+         for param in model.parameters():
+             param.requires_grad = training
+
+     def forward(
+         self,
+         img: Optional[torch.Tensor] = None,
+         label_img: Optional[torch.Tensor] = None,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         style_noise: float = 0,
+         label_noise: float = 0,
+         **kwargs: Any
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Forward pass of the model.
+         Args:
+             img (Optional[torch.Tensor]): Input style image tensor.
+             label_img (Optional[torch.Tensor]): Target image tensor to reconstruct during training.
+             input_ids (Optional[torch.Tensor]): Tokenized input IDs.
+             attention_mask (Optional[torch.Tensor]): Attention mask for the inputs.
+             style_noise (float): Standard deviation of noise added to the style latent sequence.
+             label_noise (float): Standard deviation of noise added to the label latent sequence.
+             **kwargs: Additional arguments.
+         Returns:
+             Tuple containing:
+                 - mse_loss (torch.Tensor): Mean squared error loss.
+                 - pred_latent (torch.Tensor): Predicted latent representations.
+                 - z_label (torch.Tensor): Sampled label latent vector from the VAE.
+         """
+         assert label_img is not None, 'label_img must be provided during training'
+
+         posterior_style = self.vae.encode(img.float())
+         z_style = posterior_style.latent_dist.sample()
+         z_style_sequence = self.query_rearrange(z_style)
+         if style_noise > 0:
+             z_style_sequence = z_style_sequence + torch.randn_like(z_style_sequence) * style_noise
+
+         if self.style_enc == "mean":
+             style_global = z_style_sequence.mean(dim=1, keepdim=True)  # (b, 1, d)
+         elif self.style_enc in ["MLP", "MLP2"]:
+             style_scores = self.style_encoder(z_style_sequence)  # (b, w, 1)
+             style_weights = torch.softmax(style_scores, dim=1)  # (b, w, 1)
+             style_global = (z_style_sequence * style_weights).sum(dim=1, keepdim=True)  # (b, 1, d)
+         elif self.style_enc == "full":
+             style_global = z_style_sequence  # (b, w, d)
+         else:
+             raise ValueError(f"Unknown style_enc type: {self.style_enc}")
+
+         w_style = style_global.size(1)
+         style_token_embed = self.vae_to_t5(style_global)  # (b, w_style, t5_d_model)
+
+         posterior_label = self.vae.encode(label_img.float())
+         z_label = posterior_label.latent_dist.sample()
+         z_label_sequence = self.query_rearrange(z_label)  # (b, w, d)
+
+         if label_noise > 0:
+             z_label_sequence_noisy = z_label_sequence + torch.randn_like(z_label_sequence) * label_noise
+         else:
+             z_label_sequence_noisy = z_label_sequence
+
+         sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
+
+         label_embeds = self.vae_to_t5(z_label_sequence_noisy)  # (b, w, t5_d_model)
+         decoder_inputs_embeds = torch.cat(
+             [style_token_embed, sos, label_embeds[:, :-1]], dim=1
+         )  # (b, w_style + 1 + (w_label - 1), t5_d_model)
+
+         output = self.T5(input_ids, attention_mask=attention_mask, decoder_inputs_embeds=decoder_inputs_embeds)  # (b, w_style + 1 + (w_label - 1), t5_d_model)
+         all_vae_latent = self.t5_to_vae(output.logits)  # (b, w_style + 1 + (w_label - 1), vae_latent_size)
+
+         vae_latent = all_vae_latent[:, w_style:, :]  # (b, w_label, vae_latent_size)
+
+         min_seq_len = min(vae_latent.size(1), z_label_sequence.size(1))
+         vae_latent_trimmed = vae_latent[:, :min_seq_len]
+         z_label_sequence_trimmed = z_label_sequence[:, :min_seq_len]
+
+         mse_loss = self.mse_criterion(vae_latent_trimmed, z_label_sequence_trimmed)
+         pred_latent = self.z_rearrange(vae_latent_trimmed)
+
+         return mse_loss, pred_latent, z_label
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         gen_text: str,
+         style_img: torch.Tensor,
+         **kwargs: Any
+     ) -> Image.Image:
+         """
+         Generate an image from a text prompt and a style image.
+         Args:
+             gen_text (str): Generation-related text prompt.
+             style_img (torch.Tensor): Style image tensor. Expected shape is either 3D or 4D.
+             **kwargs: Additional keyword arguments.
+         Returns:
+             Image.Image: Generated image as a PIL image.
+         """
+         if style_img.ndim == 3:
+             style_img = style_img.unsqueeze(0)
+         elif style_img.ndim == 4:  # (b, c, h, w)
+             pass
+         else:
+             raise ValueError('style_img must be 3D or 4D')
+
+         imgs, _ = self._generate(texts=[gen_text], imgs=style_img, **kwargs)
+         imgs = (imgs + 1) / 2
+         return F.to_pil_image(imgs[0].detach().cpu())
+
+     @torch.inference_mode()
+     def generate_batch(
+         self,
+         gen_texts: List[str],
+         style_imgs: torch.Tensor,
+         **kwargs: Any
+     ) -> List[Image.Image]:
+         """
+         Generate a batch of images from generation texts and style images.
+         Args:
+             gen_texts (List[str]): List of generation-related text prompts.
+             style_imgs (torch.Tensor): Batch of style images (4D tensor).
+             **kwargs: Additional keyword arguments.
+         Returns:
+             List[Image.Image]: List of generated images as PIL images.
+         """
+         assert style_imgs.ndim == 4, 'style_imgs must be 4D'
+         assert len(gen_texts) == len(style_imgs), 'gen_texts and style_imgs must have the same length'
+         texts = list(gen_texts)
+
+         imgs, _ = self._generate(texts=texts, imgs=style_imgs, **kwargs)
+         imgs = (imgs + 1) / 2
+
+         out_imgs = []
+         for i in range(imgs.size(0)):
+             out_imgs.append(F.to_pil_image(imgs[i].detach().cpu()))
+         return out_imgs
+
+     def _generate(
+         self,
+         texts: Optional[List[str]] = None,
+         imgs: Optional[torch.Tensor] = None,
+         input_ids: Optional[torch.Tensor] = None,
+         z_sequence: Optional[torch.Tensor] = None,
+         max_new_tokens: int = 256,
+         stopping_criteria: str = 'latent',
+         stopping_after: int = 10,
+         stopping_patience: int = 1
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Internal generation routine that combines textual and visual inputs to iteratively generate
+         latent representations and decode them into images.
+         Args:
+             texts (Optional[List[str]]): List of text prompts.
+             imgs (Optional[torch.Tensor]): Input image tensor.
+             input_ids (Optional[torch.Tensor]): Tokenized input IDs.
+             z_sequence (Optional[torch.Tensor]): Precomputed latent sequence.
+             max_new_tokens (int): Maximum tokens to generate.
+             stopping_criteria (str): Criterion for stopping ('latent' or 'none').
+             stopping_after (int): Number of tokens to check for the stopping condition.
+             stopping_patience (int): Patience parameter for the stopping condition.
+         Returns:
+             Tuple containing:
+                 - imgs (torch.Tensor): Generated images.
+                 - canvas_sequence (torch.Tensor): Generated latent canvas sequence.
+         """
+         assert texts is not None or input_ids is not None, 'Either texts or input_ids must be provided'
+         assert imgs is not None or z_sequence is not None, 'Either imgs or z_sequence must be provided'
+
+         if input_ids is None:
+             input_ids = self.tokenizer(texts, return_tensors='pt', padding=True).input_ids
+             input_ids = input_ids.to(self.device)
+
+         if z_sequence is None:
+             posterior_style = self.vae.encode(imgs.float())
+             z_style = posterior_style.latent_dist.sample()
+             z_style_sequence = self.query_rearrange(z_style)
+             z_sequence = z_style_sequence
+
+         if self.style_enc == "mean":
+             style_global = z_sequence.mean(dim=1, keepdim=True)  # (b, 1, d)
+         elif self.style_enc in ["MLP", "MLP2"]:
+             style_scores = self.style_encoder(z_sequence)  # (b, w, 1)
+             style_weights = torch.softmax(style_scores, dim=1)  # (b, w, 1)
+             style_global = (z_sequence * style_weights).sum(dim=1, keepdim=True)  # (b, 1, d)
+         elif self.style_enc == "full":
+             style_global = z_sequence  # (b, w, d)
+         else:
+             raise ValueError(f"Unknown style_enc type: {self.style_enc}")
+
+         w_style = style_global.size(1)
+         style_token_embed = self.vae_to_t5(style_global)  # (b, w_style, t5_d_model)
+
+         # prepare for decoder input
+         sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))
+         pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=input_ids.size(0))
+
+         generated_latents: List[torch.Tensor] = []
+         active = torch.ones(input_ids.size(0), dtype=torch.bool, device=self.device)  # for batch processing
+
+         for step in range(max_new_tokens):
+             if len(generated_latents) == 0:
+                 decoder_inputs_embeds = torch.cat([style_token_embed, sos], dim=1)  # (b, w_style + 1, t5_d_model)
+             else:
+                 lat_seq = torch.stack(generated_latents, dim=1)  # (b, t, vae_latent_size)
+                 lat_embeds = self.vae_to_t5(lat_seq)  # (b, t, t5_d_model)
+                 decoder_inputs_embeds = torch.cat([style_token_embed, sos, lat_embeds], dim=1)  # (b, w_style + 1 + t, t5_d_model)
+
+             output = self.T5(input_ids, decoder_inputs_embeds=decoder_inputs_embeds)
+             last_hidden = output.logits[:, -1:, :]  # (b, 1, t5_d_model)
+             vae_latent = self.t5_to_vae(last_hidden)[:, 0, :]  # (b, vae_latent_size)
+
+             if stopping_criteria == 'latent' and (~active).any():
+                 vae_latent = torch.where(
+                     active.unsqueeze(-1),
+                     vae_latent,
+                     pad_token.squeeze(1)
+                 )
+
+             generated_latents.append(vae_latent)
+             canvas_sequence = torch.stack(generated_latents, dim=1)  # (b, t+1, vae_latent_size)
+
+             if stopping_criteria == 'latent':
+                 similarity = torch.nn.functional.cosine_similarity(canvas_sequence, pad_token, dim=-1)  # (b, t+1)
+                 if similarity.size(1) >= stopping_after:
+                     window = similarity[:, -stopping_after:]  # (b, stopping_after)
+                     cnt = (window > self.padding_token_threshold).to(torch.int).sum(dim=1)  # (b,)
+                     new = (cnt >= (stopping_after - stopping_patience)) & active  # (b,)
+                     active = active & (~new)
+
+                 if not active.any():
+                     break
+
+             elif stopping_criteria == 'none':
+                 pass
+
+         canvas_sequence = torch.stack(generated_latents, dim=1)  # (b, t, vae_latent_size)
+         imgs = torch.clamp(self.vae.decode(self.z_rearrange(canvas_sequence)).sample, -1, 1)
+         return imgs, canvas_sequence
+
+     def _img_encode(
+         self,
+         img: torch.Tensor,
+         noise: float = 0
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Encode the input image into a latent representation using the VAE.
+         Args:
+             img (torch.Tensor): Input image tensor.
+             noise (float): Standard deviation of noise to add to the latent sequence.
+         Returns:
+             Tuple containing:
+                 - decoder_inputs_embeds (torch.Tensor): Embeddings to be used as T5 decoder inputs.
+                 - z_sequence (torch.Tensor): Rearranged latent sequence from the VAE.
+                 - z (torch.Tensor): Sampled latent vector from the VAE.
+         """
+         posterior = self.vae.encode(img.float())
+         z = posterior.latent_dist.sample()
+         z_sequence = self.query_rearrange(z)
+
+         noise_sequence = z_sequence
+         if noise > 0:
+             noise_sequence = z_sequence + torch.randn_like(z_sequence) * noise
+
+         decoder_inputs_embeds = self.vae_to_t5(noise_sequence)
+         sos = repeat(self.sos.weight, '1 d -> b 1 d', b=decoder_inputs_embeds.size(0))
+         decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)
+         return decoder_inputs_embeds, z_sequence, z
+
+     def compute_padding_token(self) -> None:
+         """
+         Compute and update the padding token.
+         Raises:
+             NotImplementedError: This method must be implemented.
+         """
+         raise NotImplementedError("compute_padding_token not implemented")
+
+     def compute_padding_token_threshold(self) -> None:
+         """
+         Compute and update the padding token threshold.
+         Raises:
+             NotImplementedError: This method must be implemented.
+         """
+         raise NotImplementedError("compute_padding_token_threshold not implemented")