Ruian7P committed on
Commit
9baf7c9
·
verified ·
1 Parent(s): 28b2302

Upload 4 files

Browse files
Files changed (4) hide show
  1. config.json +18 -0
  2. configuration_imuru.py +20 -0
  3. model.safetensors +3 -0
  4. modeling_imuru.py +412 -0
config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Imuru"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_imuru.ImuruConfig",
7
+ "AutoModel": "modeling_imuru.Imuru"
8
+ },
9
+ "model_type": "imuru",
10
+ "slices_per_query": 1,
11
+ "style_enc": "full",
12
+ "t5_name_or_path": "google-t5/t5-small",
13
+ "tokenizer_name_or_path": "google/byt5-small",
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.54.0",
16
+ "vae_channels": 1,
17
+ "vae_name_or_path": "blowing-up-groundhogs/emuru_vae"
18
+ }
configuration_imuru.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class ImuruConfig(PretrainedConfig):
    """Configuration for the Imuru model.

    Args:
        t5_name_or_path: Name or path whose T5 configuration is used to build
            the (freshly initialised) T5 backbone.
        vae_name_or_path: Name or path of the pretrained ``AutoencoderKL``.
        tokenizer_name_or_path: Name or path of the tokenizer; its length
            defines the T5 vocabulary size.
        slices_per_query: Number of latent width slices folded into one
            decoder token.
        vae_channels: Number of channels of the VAE latent.
        style_enc: How the style latent sequence is summarised before it is
            fed to the decoder. The modeling code handles ``"mean"``,
            ``"MLP"``, ``"MLP2"`` and ``"full"``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    # Must agree with the "model_type" stored in config.json ("imuru");
    # the previous value "emuru" was a leftover from the model this was
    # derived from and made the serialized and in-code types disagree.
    model_type = "imuru"

    def __init__(
        self,
        t5_name_or_path='google-t5/t5-large',
        vae_name_or_path='blowing-up-groundhogs/emuru_vae',
        tokenizer_name_or_path='google/byt5-small',
        slices_per_query=1,
        vae_channels=1,
        style_enc="mean",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.t5_name_or_path = t5_name_or_path
        self.vae_name_or_path = vae_name_or_path
        self.tokenizer_name_or_path = tokenizer_name_or_path
        self.slices_per_query = slices_per_query
        self.vae_channels = vae_channels
        self.style_enc = style_enc
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ef6240c86d9f90b5ad436e5b58e9774d7b574001145591734e4339274ac7219
3
+ size 232979496
modeling_imuru.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, T5ForConditionalGeneration, T5Config, AutoTokenizer
4
+ from .configuration_imuru import ImuruConfig
5
+ from diffusers import AutoencoderKL
6
+ from einops.layers.torch import Rearrange
7
+ from einops import repeat
8
+ from torchvision.transforms import functional as F
9
+ from typing import Optional, Tuple, List, Any
10
+ from PIL import Image
11
+
12
class Imuru(PreTrainedModel):
    """
    Imuru is a conditional generative model that integrates a T5 encoder-decoder with a VAE
    for image generation conditioned on text and style images.

    The text is fed to the T5 encoder; the T5 decoder consumes embeddings of VAE latent
    slices (style prefix, start-of-sequence embedding, then label/generated latents) and its
    outputs are projected back to the VAE latent space.

    Attributes:
        config_class (Type): Configuration class for the model.
        tokenizer (AutoTokenizer): Tokenizer loaded from the provided tokenizer configuration.
        T5 (T5ForConditionalGeneration): T5 model whose ``lm_head`` is replaced by identity,
            so ``output.logits`` holds decoder hidden states.
        sos (nn.Embedding): Start-of-sequence embedding.
        vae_to_t5 (nn.Linear): Linear projection from VAE latent space to T5 hidden space.
        t5_to_vae (nn.Linear): Linear projection from T5 hidden space back to VAE latent space.
        padding_token (nn.Parameter): Non-trainable latent used as padding / stop reference.
        padding_token_threshold (nn.Parameter): Non-trainable cosine-similarity threshold used
            by the 'latent' stopping criterion.
        vae (AutoencoderKL): Pre-trained, frozen Variational Autoencoder.
        query_rearrange (Rearrange): Layer to rearrange VAE latents into a token sequence.
        z_rearrange (Rearrange): Layer to rearrange a token sequence back to VAE latent layout.
        style_enc (str): Style summarisation mode ("mean", "MLP", "MLP2" or "full").
        style_encoder (nn.Module): Scoring network, only present for "MLP" / "MLP2".
        mse_criterion (nn.MSELoss): Mean squared error loss function.
    """
    config_class = ImuruConfig

    def __init__(self, config: ImuruConfig) -> None:
        """
        Initialize the Imuru model.

        Args:
            config (ImuruConfig): Configuration object containing model hyperparameters and paths.
        """
        super().__init__(config)

        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)

        # Only the T5 *configuration* is loaded; the backbone itself is freshly initialised
        # with a vocabulary sized to the tokenizer.
        t5_config = T5Config.from_pretrained(config.t5_name_or_path)
        t5_config.vocab_size = len(self.tokenizer)
        self.T5 = T5ForConditionalGeneration(t5_config)
        # Drop the vocabulary projection: the decoder hidden states are consumed directly.
        self.T5.lm_head = nn.Identity()
        self.sos = nn.Embedding(1, t5_config.d_model)

        vae_latent_size = 8 * config.vae_channels * config.slices_per_query
        self.vae_to_t5 = nn.Linear(vae_latent_size, t5_config.d_model)
        self.t5_to_vae = nn.Linear(t5_config.d_model, vae_latent_size, bias=False)

        self.padding_token = nn.Parameter(
            torch.tensor([[-0.4951, 0.8021, 0.3429, 0.5622, 0.5271, 0.5756, 0.7194, 0.6150]]),
            requires_grad=False,
        )
        self.padding_token_threshold = nn.Parameter(
            torch.tensor(0.484982096850872), requires_grad=False
        )

        self.query_rearrange = Rearrange(
            'b c h (w q) -> b w (q c h)', q=config.slices_per_query
        )
        self.z_rearrange = Rearrange(
            'b w (q c h) -> b c h (w q)', c=config.vae_channels, q=config.slices_per_query
        )

        # Older configs may not carry ``style_enc``; fall back to mean pooling.
        self.style_enc = getattr(config, 'style_enc', "mean")
        print(f"Using style encoder: {self.style_enc}")
        if self.style_enc == "MLP":  # w -> 1
            self.style_encoder = nn.Linear(vae_latent_size, 1)
        elif self.style_enc == "MLP2":
            self.style_encoder = nn.Sequential(
                nn.Linear(vae_latent_size, vae_latent_size),
                nn.SiLU(),
                nn.Linear(vae_latent_size, 1)
            )

        self.mse_criterion = nn.MSELoss()
        self.init_weights()

        # The VAE is loaded after weight initialisation and kept frozen.
        self.vae = AutoencoderKL.from_pretrained(config.vae_name_or_path)
        self.set_training(self.vae, False)

    def set_training(self, model: nn.Module, training: bool) -> None:
        """
        Set the training mode for a given model and freeze/unfreeze parameters accordingly.

        Args:
            model (nn.Module): The model to set the training mode for.
            training (bool): If True, set the model to training mode; otherwise, evaluation mode.
        """
        # ``train(False)`` is exactly what ``eval()`` does.
        model.train(training)
        for param in model.parameters():
            param.requires_grad = training

    def _encode_image(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode an image with the VAE and sample a latent.

        Args:
            img (torch.Tensor): Image tensor; cast to float before encoding.

        Returns:
            Tuple containing:
                - z (torch.Tensor): Latent sampled from the VAE posterior.
                - z_sequence (torch.Tensor): The same latent rearranged to (b, w, d).
        """
        z = self.vae.encode(img.float()).latent_dist.sample()
        return z, self.query_rearrange(z)

    def _pool_style(self, z_sequence: torch.Tensor) -> torch.Tensor:
        """
        Summarise a style latent sequence according to ``self.style_enc``.

        Args:
            z_sequence (torch.Tensor): Style latent sequence of shape (b, w, d).

        Returns:
            torch.Tensor: (b, 1, d) for "mean", "MLP" and "MLP2"; the unchanged
            (b, w, d) sequence for "full".

        Raises:
            ValueError: If ``self.style_enc`` is not one of the supported modes.
        """
        if self.style_enc == "mean":
            return z_sequence.mean(dim=1, keepdim=True)  # (b, 1, d)
        if self.style_enc in ("MLP", "MLP2"):
            # Learned attention-style pooling over the width axis.
            style_scores = self.style_encoder(z_sequence)  # (b, w, 1)
            style_weights = torch.softmax(style_scores, dim=1)  # (b, w, 1)
            return (z_sequence * style_weights).sum(dim=1, keepdim=True)  # (b, 1, d)
        if self.style_enc == "full":
            return z_sequence  # (b, w, d)
        raise ValueError(f"Unknown style_enc type: {self.style_enc}")

    def forward_nonAR(
        self,
        img: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        noise: float = 0,
        label_img: Optional[torch.Tensor] = None,
        **kwargs: Any
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass that feeds the style image latents as decoder inputs.

        Args:
            img (Optional[torch.Tensor]): Input style image tensor.
            input_ids (Optional[torch.Tensor]): Tokenized input IDs.
            attention_mask (Optional[torch.Tensor]): Attention mask for the inputs.
            noise (float): Standard deviation of the noise added to the style latents.
            label_img (Optional[torch.Tensor]): Target image; required in training mode.
            **kwargs: Additional arguments (ignored).

        Returns:
            Tuple containing:
                - mse_loss (torch.Tensor): MSE between predicted and label latents in
                  training mode, a zero scalar otherwise.
                - pred_latent (torch.Tensor): Predicted latents in VAE layout.
                - z (torch.Tensor): Latent sampled from the VAE for ``img``.
        """
        decoder_inputs_embeds, _, z = self._img_encode(img, noise)

        output = self.T5(
            input_ids,
            attention_mask=attention_mask,
            decoder_inputs_embeds=decoder_inputs_embeds,
        )
        # The last position has no target, so it is dropped.
        vae_latent = self.t5_to_vae(output.logits[:, :-1])
        pred_latent = self.z_rearrange(vae_latent)

        if not self.training:
            return torch.tensor(0.0, device=self.device), pred_latent, z

        assert label_img is not None, 'label_img must be provided during training'
        _, z_label_sequence = self._encode_image(label_img)

        # Style and label images may have different widths: compare the common prefix only.
        seq_len = min(vae_latent.size(1), z_label_sequence.size(1))
        mse_loss = self.mse_criterion(vae_latent[:, :seq_len], z_label_sequence[:, :seq_len])
        return mse_loss, pred_latent, z

    def forward(
        self,
        img: Optional[torch.Tensor] = None,
        label_img: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        style_noise: float = 0,
        label_noise: float = 0,
        **kwargs: Any
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Teacher-forced forward pass: predict the label latents from text and style.

        Args:
            img (Optional[torch.Tensor]): Input style image tensor.
            label_img (Optional[torch.Tensor]): Target image tensor (required).
            input_ids (Optional[torch.Tensor]): Tokenized input IDs.
            attention_mask (Optional[torch.Tensor]): Attention mask for the inputs.
            style_noise (float): Standard deviation of the noise added to the style latents.
            label_noise (float): Standard deviation of the noise added to the label latents
                used as decoder inputs (the loss targets stay noise-free).
            **kwargs: Additional arguments (ignored).

        Returns:
            Tuple containing:
                - mse_loss (torch.Tensor): Mean squared error loss.
                - pred_latent (torch.Tensor): Predicted latents in VAE layout.
                - z_label (torch.Tensor): Latent sampled from the VAE for ``label_img``.

        Raises:
            ValueError: If the configured style encoder type is unknown.
        """
        assert label_img is not None, 'label_img must be provided during training'

        _, z_style_sequence = self._encode_image(img)
        if style_noise > 0:
            z_style_sequence = z_style_sequence + torch.randn_like(z_style_sequence) * style_noise

        style_global = self._pool_style(z_style_sequence)  # (b, w_style, d)
        w_style = style_global.size(1)
        style_token_embed = self.vae_to_t5(style_global)  # (b, w_style, t5_d_model)

        z_label, z_label_sequence = self._encode_image(label_img)  # sequence: (b, w, d)

        z_label_sequence_noisy = z_label_sequence
        if label_noise > 0:
            z_label_sequence_noisy = z_label_sequence + torch.randn_like(z_label_sequence) * label_noise

        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=input_ids.size(0))

        label_embeds = self.vae_to_t5(z_label_sequence_noisy)  # (b, w, t5_d_model)
        # Shift right: position i of the label part is predicted from labels < i.
        decoder_inputs_embeds = torch.cat(
            [style_token_embed, sos, label_embeds[:, :-1]], dim=1
        )  # (b, w_style + 1 + (w_label - 1), t5_d_model)

        output = self.T5(
            input_ids,
            attention_mask=attention_mask,
            decoder_inputs_embeds=decoder_inputs_embeds,
        )
        all_vae_latent = self.t5_to_vae(output.logits)  # (b, w_style + w_label, vae_latent_size)

        # Skipping the style prefix leaves exactly w_label predictions, so the
        # prediction and the target always have the same length here.
        vae_latent = all_vae_latent[:, w_style:, :]  # (b, w_label, vae_latent_size)

        mse_loss = self.mse_criterion(vae_latent, z_label_sequence)
        pred_latent = self.z_rearrange(vae_latent)

        return mse_loss, pred_latent, z_label

    @torch.inference_mode()
    def generate(
        self,
        gen_text: str,
        style_img: torch.Tensor,
        **kwargs: Any
    ) -> Image.Image:
        """
        Generate an image of ``gen_text`` rendered in the style of ``style_img``.

        Args:
            gen_text (str): Text to generate.
            style_img (torch.Tensor): Style image tensor. Expected shape is either 3D
                (a batch dimension is added) or 4D (b, c, h, w).
            **kwargs: Additional keyword arguments forwarded to ``_generate``.

        Returns:
            Image.Image: Generated image as a PIL image (first item of the batch).

        Raises:
            ValueError: If ``style_img`` is neither 3D nor 4D.
        """
        if style_img.ndim == 3:
            style_img = style_img.unsqueeze(0)
        elif style_img.ndim != 4:
            raise ValueError('style_img must be 3D or 4D')

        imgs, _ = self._generate(texts=[gen_text], imgs=style_img, **kwargs)
        # Decoded images are clamped to [-1, 1]; map them to [0, 1] for PIL conversion.
        imgs = (imgs + 1) / 2
        return F.to_pil_image(imgs[0].detach().cpu())

    @torch.inference_mode()
    def generate_batch(
        self,
        gen_texts: List[str],
        style_imgs: torch.Tensor,
        **kwargs: Any
    ) -> List[Image.Image]:
        """
        Generate a batch of images from lists of generation texts and style images.

        Args:
            gen_texts (List[str]): Texts to generate, one per style image.
            style_imgs (torch.Tensor): Batch of style images (4D tensor).
            **kwargs: Additional keyword arguments forwarded to ``_generate``.

        Returns:
            List[Image.Image]: List of generated images as PIL images.
        """
        assert style_imgs.ndim == 4, 'style_imgs must be 4D'
        assert len(gen_texts) == len(style_imgs), 'gen_texts and style_imgs must have the same length'

        imgs, _ = self._generate(texts=list(gen_texts), imgs=style_imgs, **kwargs)
        # Decoded images are clamped to [-1, 1]; map them to [0, 1] for PIL conversion.
        imgs = (imgs + 1) / 2

        return [F.to_pil_image(img.detach().cpu()) for img in imgs]

    def _generate(
        self,
        texts: Optional[List[str]] = None,
        imgs: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        z_sequence: Optional[torch.Tensor] = None,
        max_new_tokens: int = 256,
        stopping_criteria: str = 'latent',
        stopping_after: int = 10,
        stopping_patience: int = 1,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Internal generation routine that combines textual and visual inputs to iteratively
        generate latent tokens and decode them into images.

        Args:
            texts (Optional[List[str]]): List of text prompts; tokenized when ``input_ids``
                is not given.
            imgs (Optional[torch.Tensor]): Style image tensor; encoded when ``z_sequence``
                is not given.
            input_ids (Optional[torch.Tensor]): Tokenized input IDs.
            z_sequence (Optional[torch.Tensor]): Precomputed style latent sequence (b, w, d).
            max_new_tokens (int): Maximum number of latent tokens to generate (at least 1).
            stopping_criteria (str): Criteria for stopping ('latent' or 'none').
            stopping_after (int): Size of the trailing window inspected for padding tokens.
            stopping_patience (int): Number of non-padding tokens tolerated in that window.
            attention_mask (Optional[torch.Tensor]): Mask for caller-provided ``input_ids``.
                When the texts are tokenized here, the tokenizer's mask is used instead.

        Returns:
            Tuple containing:
                - imgs (torch.Tensor): Generated images, clamped to [-1, 1].
                - canvas_sequence (torch.Tensor): Generated latent sequence (b, t, d).

        Raises:
            ValueError: If ``max_new_tokens`` is smaller than 1 or the configured style
                encoder type is unknown.
        """
        assert texts is not None or input_ids is not None, 'Either texts or input_ids must be provided'
        assert imgs is not None or z_sequence is not None, 'Either imgs or z_sequence must be provided'
        if max_new_tokens < 1:
            # Without at least one step there is nothing to stack or decode.
            raise ValueError('max_new_tokens must be at least 1')

        if input_ids is None:
            encoded = self.tokenizer(texts, return_tensors='pt', padding=True)
            input_ids = encoded.input_ids
            # Texts of different lengths are padded; the mask keeps the model from
            # attending to those padding positions, as in the training forward pass.
            attention_mask = encoded.attention_mask
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        if z_sequence is None:
            _, z_sequence = self._encode_image(imgs)

        style_global = self._pool_style(z_sequence)  # (b, w_style, d)
        style_token_embed = self.vae_to_t5(style_global)  # (b, w_style, t5_d_model)

        batch_size = input_ids.size(0)
        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=batch_size)
        pad_token = repeat(self.padding_token, '1 d -> b 1 d', b=batch_size)

        # The text encoding and the decoder prefix do not change between steps,
        # so they are computed once instead of on every iteration.
        encoder_outputs = self.T5.encoder(input_ids=input_ids, attention_mask=attention_mask)
        prefix_embeds = torch.cat([style_token_embed, sos], dim=1)  # (b, w_style + 1, t5_d_model)

        generated_latents: List[torch.Tensor] = []
        # Per-sample flag: False once a sample has hit the stopping condition.
        active = torch.ones(batch_size, dtype=torch.bool, device=self.device)

        for _ in range(max_new_tokens):
            if generated_latents:
                lat_seq = torch.stack(generated_latents, dim=1)  # (b, t, vae_latent_size)
                lat_embeds = self.vae_to_t5(lat_seq)  # (b, t, t5_d_model)
                decoder_inputs_embeds = torch.cat([prefix_embeds, lat_embeds], dim=1)
            else:
                decoder_inputs_embeds = prefix_embeds

            output = self.T5(
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                decoder_inputs_embeds=decoder_inputs_embeds,
            )
            last_hidden = output.logits[:, -1:, :]  # (b, 1, t5_d_model)
            vae_latent = self.t5_to_vae(last_hidden)[:, 0, :]  # (b, vae_latent_size)

            if stopping_criteria == 'latent' and (~active).any():
                # Finished samples keep emitting the padding token.
                vae_latent = torch.where(
                    active.unsqueeze(-1),
                    vae_latent,
                    pad_token.squeeze(1)
                )

            generated_latents.append(vae_latent)

            if stopping_criteria == 'latent' and len(generated_latents) >= stopping_after:
                # Only the trailing window decides whether a sample is done.
                window_latents = torch.stack(generated_latents[-stopping_after:], dim=1)
                window = torch.nn.functional.cosine_similarity(
                    window_latents, pad_token, dim=-1
                )  # (b, stopping_after)
                cnt = (window > self.padding_token_threshold).to(torch.int).sum(dim=1)  # (b,)
                finished = (cnt >= (stopping_after - stopping_patience)) & active  # (b,)
                active = active & (~finished)

                if not active.any():
                    break

        canvas_sequence = torch.stack(generated_latents, dim=1)  # (b, t, vae_latent_size)
        imgs = torch.clamp(self.vae.decode(self.z_rearrange(canvas_sequence)).sample, -1, 1)
        return imgs, canvas_sequence

    def _img_encode(
        self,
        img: torch.Tensor,
        noise: float = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode the input image into T5 decoder inputs using the VAE.

        Args:
            img (torch.Tensor): Input image tensor.
            noise (float): Standard deviation of noise to add to the latent sequence.

        Returns:
            Tuple containing:
                - decoder_inputs_embeds (torch.Tensor): Start-of-sequence embedding followed
                  by the projected (optionally noised) latent sequence.
                - z_sequence (torch.Tensor): Noise-free rearranged latent sequence from the VAE.
                - z (torch.Tensor): Sampled latent from the VAE.
        """
        z, z_sequence = self._encode_image(img)

        noise_sequence = z_sequence
        if noise > 0:
            noise_sequence = z_sequence + torch.randn_like(z_sequence) * noise

        latent_embeds = self.vae_to_t5(noise_sequence)
        sos = repeat(self.sos.weight, '1 d -> b 1 d', b=latent_embeds.size(0))
        decoder_inputs_embeds = torch.cat([sos, latent_embeds], dim=1)
        return decoder_inputs_embeds, z_sequence, z

    def compute_padding_token(self) -> None:
        """
        Compute and update the padding token.

        Raises:
            NotImplementedError: Always; this method is not implemented.
        """
        raise NotImplementedError("compute_padding_token not implemented")

    def compute_padding_token_threshold(self) -> None:
        """
        Compute and update the padding token threshold.

        Raises:
            NotImplementedError: Always; this method is not implemented.
        """
        raise NotImplementedError("compute_padding_token_threshold not implemented")