carminezacc committed on
Commit
47e6fbb
·
verified ·
1 Parent(s): 96488d8

Upload eruku_continuous_inf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eruku_continuous_inf.py +569 -0
eruku_continuous_inf.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os as _os
3
+ from transformers import AutoTokenizer
4
+ from transformers import T5ForConditionalGeneration, T5Config
5
+ from custom_datasets import HFDataCollector
6
+ from einops.layers.torch import Rearrange
7
+ from einops import rearrange, repeat
8
+ from torch.nn import MSELoss, CTCLoss, CrossEntropyLoss
9
+ from pathlib import Path
10
+ from torchvision.utils import make_grid, save_image
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ from models.origami import OrigamiNet
13
+ from diffusers import AutoencoderKL
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from torchvision.transforms import Normalize
16
+ import numpy as np
17
+ import torch.nn as nn
18
+ from typing import Tuple
19
+
20
# Debug-oriented defaults for clearer NCCL/CUDA error reporting.
# ``setdefault`` only fills in variables the user has not already exported,
# so an explicit value in the environment always wins.
for _debug_env_name in (
    "CUDA_LAUNCH_BLOCKING",
    "TORCH_NCCL_BLOCKING_WAIT",
    "TORCH_NCCL_ASYNC_ERROR_HANDLING",
):
    _os.environ.setdefault(_debug_env_name, "1")
del _debug_env_name
24
+
25
+
26
def _safe_int_from_maybe_tensor(value, fallback_min: int = 64) -> int:
    """Best-effort conversion of a Python number or tensor to ``int``.

    Args:
        value: A plain Python value accepted by ``int()`` or a
            ``torch.Tensor`` (CPU or CUDA). For a tensor with one or more
            dimensions only its first element (in flattened order) is used.
        fallback_min: Value returned (after ``int()``) whenever the
            conversion fails for any reason.

    Returns:
        The converted integer, or ``int(fallback_min)`` on any error
        (non-numeric input, empty tensor, failing CUDA op, ...).
    """
    if not isinstance(value, torch.Tensor):
        try:
            return int(value)
        except Exception:
            # Last resort: hand back the conservative minimum width.
            return int(fallback_min)

    try:
        # Reduce to a single element; an empty tensor raises here and falls
        # through to the fallback below.
        first = value if value.dim() == 0 else value.reshape(-1)[0]
        if first.is_cuda:
            # Synchronize so a pending CUDA error is raised at this point
            # rather than at the .item() call; a failure of the sync itself
            # is ignored on purpose.
            try:
                torch.cuda.synchronize(first.device)
            except Exception:
                pass
        # Move to CPU before turning the element into a Python scalar.
        return int(first.detach().to("cpu").item())
    except Exception:
        return int(fallback_min)
50
+
51
def pad_images(images, padding_value=1):
    """Right-pad a list of ``(c, h, w)`` images to a common width and batch them.

    Args:
        images: Sequence of 3-D tensors laid out as ``(c, h, w)``. Widths may
            differ; ``pad_sequence`` requires the other dimensions to match.
        padding_value: Fill value used for the columns added on the right.

    Returns:
        A contiguous tensor of shape ``(b, c, h, max_w)``.
    """
    # pad_sequence pads along the leading dimension, so move width to the front.
    width_first = [img.permute(2, 0, 1) for img in images]
    # Default batch_first=False gives (max_w, b, c, h).
    stacked = pad_sequence(width_first, padding_value=padding_value)
    # Back to the conventional (b, c, h, w) layout.
    return stacked.permute(1, 2, 3, 0).contiguous()
55
+
56
+
57
+
58
+
59
+ # Number of special-token classes predicted for each decoder position: 0 = SOG, 1 = EOG, 2 = IMG (image slice)
60
+ SPECIAL_TOKEN_COUNT = 3
61
+
62
class Emuru(torch.nn.Module):
    """T5 encoder-decoder that autoregressively predicts VAE latent slices.

    The encoder receives the byte-tokenized string ``"{style}<sog>{gen}"``;
    the decoder receives VAE latent slices of the style image (and, during
    training, of the target image) projected to the T5 width by
    ``query_emb``. Each decoder output is mapped to a latent slice
    (``t5_to_vae``) and to one of ``SPECIAL_TOKEN_COUNT`` classes
    (``t5_to_special``): 0 = SOG, 1 = EOG, 2 = image slice.
    """

    # One latent column corresponds to this many pixel columns; the factor is
    # the one used throughout this class (``// 8``, ``* 8``,
    # ``repeat_interleave(8)``).
    _PIXELS_PER_LATENT = 8
    # Lower bound applied to every pixel width before slicing latents.
    _MIN_PIXEL_WIDTH = 64

    # (attribute name, file name) pairs handled by save/load_pretrained.
    # The first group is the historical file set and is mandatory on load.
    _REQUIRED_COMPONENTS = (
        ('T5', 'T5.pth'),
        ('vae', 'VAE.pth'),
        ('ocr', 'OCR.pth'),
        ('query_emb', 'query_emb.pth'),
        ('sos', 'sos.pth'),
    )
    # Trained heads/embeddings that older checkpoints did not store; they are
    # always written and are loaded only when the file exists.
    _OPTIONAL_COMPONENTS = (
        ('sog', 'sog.pth'),
        ('eog', 'eog.pth'),
        ('t5_to_vae', 't5_to_vae.pth'),
        ('t5_to_special', 't5_to_special.pth'),
        ('t5_to_ocr', 't5_to_ocr.pth'),
        ('uncond_embedding', 'uncond_embedding.pth'),
    )

    def __init__(self, t5_checkpoint='google-t5/t5-base',
                 vae_checkpoint='blowing-up-groundhogs/emuru_vae',
                 ocr_checkpoint='files/checkpoints/Origami_bw_img/origami.pth',
                 slices_per_query=1, channels=1,
                 text_dropout_probability=0.0, img_dropout_probability=0.0):
        """Build the tokenizer, T5 backbone, VAE, OCR network and projection heads.

        Args:
            t5_checkpoint: Name or path used to load the T5 *config* only; the
                T5 weights are freshly initialised from that config.
            vae_checkpoint: Name or path passed to ``AutoencoderKL.from_pretrained``.
            ocr_checkpoint: Path passed to ``OrigamiNet.from_checkpoint``.
            slices_per_query: Number of latent columns grouped into one
                decoder position.
            channels: Latent channel count used by the rearrange layers.
            text_dropout_probability: Accepted for interface compatibility;
                not used by this constructor.
            img_dropout_probability: Accepted for interface compatibility;
                not used by this constructor.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('google/byt5-small')  # per-character tokenizer
        self.tokenizer.add_tokens(["<sog>"])
        self.data_collator = HFDataCollector(tokenizer=self.tokenizer)
        self.t5_name_or_path = t5_checkpoint

        self.padding_token = torch.tensor([[-0.4951, 0.8021, 0.3429, 0.5622, 0.5271, 0.5756, 0.7194, 0.6150]])
        self.padding_token_threshold = 0.484982096850872

        config = T5Config.from_pretrained(t5_checkpoint)
        config.vocab_size = len(self.tokenizer)
        self.T5 = T5ForConditionalGeneration(config)
        # Expose a HF-like config for downstream trainers expecting model.config
        self.config = self.T5.config
        # Ensure a valid identifier is present for downstream AutoProcessor lookups
        if not getattr(self.config, "_name_or_path", None):
            self.config._name_or_path = str(self.t5_name_or_path)
        # With an identity head, ``output.logits`` carries the decoder
        # features that the projection heads below consume.
        self.T5.lm_head = torch.nn.Identity()
        self.normalize = Normalize(0.5, 0.5)
        self.sos = torch.nn.Embedding(1, config.d_model)
        self.sog = torch.nn.Embedding(1, config.d_model)
        self.eog = torch.nn.Embedding(1, config.d_model)

        self.vae = AutoencoderKL.from_pretrained(vae_checkpoint)

        vae_latent_dim = 8  # self.vae.config.get('latent_channels', 8)

        self.query_emb = torch.nn.Linear(vae_latent_dim * channels * slices_per_query, config.d_model)
        self.t5_to_vae = torch.nn.Linear(config.d_model, vae_latent_dim * channels * slices_per_query)
        self.t5_to_special = torch.nn.Linear(config.d_model, SPECIAL_TOKEN_COUNT)
        self.t5_to_ocr = torch.nn.Linear(config.d_model, len(self.tokenizer), bias=False)

        self.uncond_embedding = torch.nn.Embedding(1, config.d_model)
        self.dropout_probability = 0.0
        self.drop_text = False
        self.drop_img = False

        # The VAE and the OCR network are frozen and kept in eval mode.
        self.set_training(self.vae, False)

        self.ocr = OrigamiNet.from_checkpoint(ocr_checkpoint, o_classes=165, n_channels=1)
        self.set_training(self.ocr, False)

        self.query_rearrange = Rearrange('b c h (w q) -> b w (q c h)', q=slices_per_query)
        self.special_rearrange = torch.nn.Identity()
        self.z_rearrange = Rearrange('b w (q c h) -> b c h (w q)', c=channels, q=slices_per_query)
        self.z_rearrange_eval = Rearrange('w b (q c h) -> b c h (w q)', c=channels, q=slices_per_query)

        self.mse_criterion = MSELoss()  # TODO: change reduction if you intend to add a mask
        self.ce_criterion = CrossEntropyLoss()
        self.trainer = None
        # Weight of (ce + mse) in the total loss; the OCR loss gets 1 - alpha.
        self.alpha = 1.0
        # Minimal attributes for TRL compatibility
        self.warnings_issued = {}
        self._model_tags = set()

    def add_model_tags(self, tags):
        """Record one tag (``str``) or several (list/tuple/set) in ``_model_tags``.

        Other input types are ignored, and unhashable tags are skipped
        silently so tagging never interrupts training.
        """
        try:
            if isinstance(tags, (list, tuple, set)):
                self._model_tags.update(tags)
            elif isinstance(tags, str):
                self._model_tags.add(tags)
        except TypeError:
            # Unhashable tag: tagging is best-effort, so do nothing.
            pass

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing - delegate to T5 model"""
        if hasattr(self.T5, 'gradient_checkpointing_enable'):
            self.T5.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing - delegate to T5 model"""
        if hasattr(self.T5, 'gradient_checkpointing_disable'):
            self.T5.gradient_checkpointing_disable()

    def set_training(self, model, training):
        """Put ``model`` in train/eval mode and (un)freeze all of its parameters."""
        if training:
            model.train()
        else:
            model.eval()
        for param in model.parameters():
            param.requires_grad = training

    def _img_encode(self, img):
        """Normalize ``img`` with ``Normalize(0.5, 0.5)`` and sample VAE latents for it."""
        img = self.normalize(img)
        # Ensure contiguous memory layout before encode to avoid kernel issues
        img = img.contiguous()
        return self.vae.encode(img.float()).latent_dist.sample()

    @staticmethod
    def _width_at(widths, index):
        """Return the pixel width for sample ``index`` as a Python int.

        ``widths`` may be a single int, a 0-dim tensor (both shared by the
        whole batch) or any indexable container of per-sample widths.
        """
        if isinstance(widths, int):
            return widths
        if isinstance(widths, torch.Tensor) and widths.dim() == 0:
            # A 0-dim tensor cannot be indexed; treat it as a shared width.
            return _safe_int_from_maybe_tensor(widths)
        item = widths[index] if hasattr(widths, '__getitem__') else widths
        return _safe_int_from_maybe_tensor(item)

    def _clip_latents(self, latents, pixel_width):
        """Return the leading latent columns of one sample covering ``pixel_width`` pixels.

        The width is clamped in *pixel* units to
        ``[_MIN_PIXEL_WIDTH, latent_width * _PIXELS_PER_LATENT]`` and only
        then converted to a column count.
        """
        max_pixel_width = latents.shape[-1] * self._PIXELS_PER_LATENT
        pixel_width = max(self._MIN_PIXEL_WIDTH, min(pixel_width, max_pixel_width))
        return latents[..., :pixel_width // self._PIXELS_PER_LATENT]

    @torch.no_grad()
    def get_model_inputs(self, style_img, gen_img, style_len, gen_len, max_img_len):
        """Encode images with the VAE and assemble decoder inputs plus special-token labels.

        Args:
            style_img: Sequence of ``(c, h, w)`` style image tensors.
            gen_img: Sequence of ``(c, h, w)`` target image tensors, or ``None``.
            style_len: Pixel width of each style image (int, tensor or sequence).
            gen_len: Pixel width of each target image, or ``None`` to omit
                the generation part.
            max_img_len: Maximum total pixel length; sequences are truncated
                to ``max_img_len // 8`` positions.

        Returns:
            dict with ``'decoder_inputs_embeds'`` (``(b, seq, features)`` raw
            latent slices, padded with 1) and ``'specials'`` (``(b, seq)``
            long tensor: 0 = SOG, 1 = EOG / padding, 2 = image slice).
        """
        device = self.T5.device
        bs = len(style_img)
        decoder_inputs_embeds_list = []
        specials_list = []

        # Move images to device and pad them
        style_img = pad_images([img.to(device) for img in style_img])

        gen_img_embeds = None
        if gen_img is not None:
            gen_img = pad_images([img.to(device) for img in gen_img])
            gen_img_embeds = self._img_encode(gen_img)

        style_img_embeds = self._img_encode(style_img)

        for el in range(bs):
            style_slice = self._clip_latents(style_img_embeds[el], self._width_at(style_len, el))
            sample_embeds_parts = [style_slice]
            # Label lengths follow the slice actually taken so embeddings and
            # labels can never get out of step.
            specials_parts = [torch.full((style_slice.shape[-1],), 2.0)]  # Img token

            if gen_img_embeds is not None and gen_len is not None:
                gen_slice = self._clip_latents(gen_img_embeds[el], self._width_at(gen_len, el))
                sample_embeds_parts.extend([
                    torch.ones(1, 8, 1, device=device),  # SOG token placeholder
                    gen_slice,
                    torch.ones(1, 8, 1, device=device),  # EOG token placeholder
                ])
                specials_parts.extend([
                    torch.zeros(1),  # SOG
                    torch.full((gen_slice.shape[-1],), 2.0),  # Img
                    torch.ones(1),  # EOG
                ])

            sample_embeds = torch.cat(sample_embeds_parts, dim=-1)
            # (c=1, h, w) -> (w, h): one feature vector per latent column.
            sample_embeds = rearrange(sample_embeds, 'c h w -> w (h c)', h=sample_embeds.shape[1], c=1)
            decoder_inputs_embeds_list.append(sample_embeds)
            specials_list.append(torch.cat(specials_parts, dim=0).to(device))

        # Pad sequences and ensure consistent shapes
        decoder_inputs_embeds_padded = pad_sequence(decoder_inputs_embeds_list, padding_value=1, batch_first=True)
        specials_padded = pad_sequence(specials_list, padding_value=1, batch_first=True)

        # Ensure we don't exceed max_img_len
        max_seq_len = max_img_len // self._PIXELS_PER_LATENT
        if decoder_inputs_embeds_padded.size(1) > max_seq_len:
            decoder_inputs_embeds_padded = decoder_inputs_embeds_padded[:, :max_seq_len]
        if specials_padded.size(1) > max_seq_len:
            specials_padded = specials_padded[:, :max_seq_len]

        return {
            'decoder_inputs_embeds': decoder_inputs_embeds_padded,
            'specials': specials_padded.long(),
        }

    def forward(self, decoder_inputs_embeds_vae, specials, style_text, gen_text, ce_multiplier=1.0):
        """Teacher-forced training step.

        Args:
            decoder_inputs_embeds_vae: ``(b, seq, features)`` latent slices as
                produced by ``get_model_inputs``.
            specials: ``(b, seq)`` long tensor of special-token labels.
            style_text: Sequence of style transcriptions.
            gen_text: Sequence of target transcriptions.
            ce_multiplier: Scale applied to the special-token cross-entropy.

        Returns:
            ``(losses, pred_latent)`` where ``losses`` is a dict with keys
            ``'loss'``, ``'mse_loss'``, ``'ce_loss'`` and ``'ocr_loss'`` and
            ``pred_latent`` is the predicted latent after ``z_rearrange``.
        """
        device = self.T5.device
        batch_size = decoder_inputs_embeds_vae.size(0)

        with torch.no_grad():
            prompts = [f"{style}<sog>{gen}" for style, gen in zip(style_text, gen_text)]
            encoded_text = self.tokenizer(prompts, padding=True, return_tensors="pt")

        decoder_inputs_embeds = self.query_emb(decoder_inputs_embeds_vae)

        # Positions labelled SOG / EOG hold placeholder latents; replace their
        # projections with the learned SOG / EOG embeddings (out of place, so
        # no host sync and no in-place write on a graph tensor).
        sog_mask = (specials == 0).unsqueeze(-1)
        eog_mask = (specials == 1).unsqueeze(-1)
        embed_dtype = decoder_inputs_embeds.dtype
        decoder_inputs_embeds = torch.where(sog_mask, self.sog.weight.to(embed_dtype), decoder_inputs_embeds)
        decoder_inputs_embeds = torch.where(eog_mask, self.eog.weight.to(embed_dtype), decoder_inputs_embeds)

        # Prepend the start-of-sequence embedding.
        sos = self.sos.weight.unsqueeze(0).expand(batch_size, -1, -1)
        decoder_inputs_embeds = torch.cat([sos, decoder_inputs_embeds], dim=1)

        inputs_embeds = self.T5.shared(encoded_text['input_ids'].to(device))
        # Per-sample conditioning dropout; the same draw is shared by the text
        # and image branches.
        drop_ids = torch.rand(inputs_embeds.shape[0], device=inputs_embeds.device) < self.dropout_probability
        if self.drop_text:
            inputs_embeds = torch.where(drop_ids[:, None, None], self.uncond_embedding.weight, inputs_embeds)
        if self.drop_img:
            decoder_inputs_embeds = torch.where(drop_ids[:, None, None], self.uncond_embedding.weight, decoder_inputs_embeds)

        output = self.T5(
            inputs_embeds=inputs_embeds,
            attention_mask=encoded_text['attention_mask'].to(device),
            decoder_inputs_embeds=decoder_inputs_embeds,
        )

        # Output at position i predicts input i + 1; the last output has no target.
        hidden = output.logits[:, :-1]
        vae_latent = self.t5_to_vae(hidden)
        special_latent = self.t5_to_special(hidden)  # [bs, w//8, 3]
        pred_latent = self.z_rearrange(vae_latent)
        special_pred = self.special_rearrange(special_latent)

        ce_loss = ce_multiplier * self.ce_criterion(special_pred.flatten(0, 1), specials.flatten(0, 1))

        # Only image positions contribute a non-zero MSE term; the mean is
        # still taken over every position (TODO: consider normalising by the mask).
        mse_mask = (specials == 2).unsqueeze(2)
        gt = decoder_inputs_embeds_vae * mse_mask
        vae_latent = vae_latent * mse_mask
        mse_loss = self.mse_criterion(vae_latent, gt)

        if self.alpha < 1.0:
            pred_img = self.vae.decode(pred_latent).sample
            # NOTE(review): the prediction is decoded after ``z_rearrange``
            # while the target is only unsqueezed, so the two latents have
            # different layouts - confirm this is intended.
            gt_img = self.vae.decode(decoder_inputs_embeds_vae.unsqueeze(1)).sample
            ocr_preds = self.ocr(pred_img)
            ocr_gt = self.ocr(gt_img)
            ocr_loss = self.mse_criterion(ocr_preds, ocr_gt)
        else:
            ocr_loss = torch.zeros((), device=mse_loss.device)

        loss = (ce_loss + mse_loss) * self.alpha + ocr_loss * (1 - self.alpha)
        return {'loss': loss, 'mse_loss': mse_loss, 'ce_loss': ce_loss, 'ocr_loss': ocr_loss}, pred_latent

    def split_characters(self, pred, gt, indices):
        """Render ground truth above prediction with character boundaries marked.

        Both latents are decoded by the VAE and stacked along the height
        axis. Whenever the value in ``indices`` changes at position ``idx``,
        pixel column ``idx * 8 - 1`` is set to -1. The decoded ``indices``
        text is written below the image.

        Returns:
            The annotated ``PIL.Image`` from ``write_text_below_image``.
        """
        pred = self.vae.decode(pred).sample
        gt = self.vae.decode(gt).sample
        img = torch.cat([gt, pred], dim=-2)

        curr_char = indices[0]
        for idx, char in enumerate(indices):
            if char != curr_char:
                img[:, :, :, idx * 8 - 1] = -1
                curr_char = char

        return self.write_text_below_image(img, self.tokenizer.decode(indices))

    @torch.no_grad()
    def write_text_below_image(self, image, text):
        """Convert a ``(1, 1, h, w)`` tensor in [-1, 1] to a PIL image with ``text`` drawn below.

        ``<pad>`` is shown as ``#`` and ``</s>`` as ``$``. A character is
        drawn at x = ``index * 8`` only when it differs from the previously
        drawn one.

        Returns:
            A new greyscale (``'L'``) ``PIL.Image``.
        """
        image = (torch.clamp(image, -1, 1) + 1) * 127.5
        image = rearrange(image.to(torch.uint8), '1 1 h w -> h w').cpu().numpy()
        image = Image.fromarray(image, mode='L')

        text = text.replace('<pad>', '#').replace('</s>', '$')

        # Load the font
        # NOTE(review): ``getmetrics`` / ``font.font.getsize`` depend on the
        # font object returned by the installed Pillow version - confirm.
        font = ImageFont.load_default()
        ascent, descent = font.getmetrics()
        _size, (_offset_x, offset_y) = font.font.getsize(text)

        # Calculate dimensions for the new image
        img_width, img_height = image.size
        new_height = img_height + offset_y + ascent + descent

        # Create a new image with white background
        new_image = Image.new('L', (img_width, new_height), color='white')

        # Paste the original image onto the new image
        new_image.paste(image, (0, 0))

        # Draw the text onto the new image
        draw = ImageDraw.Draw(new_image)

        curr_char = None
        for idx, char in enumerate(text):
            if char != curr_char:
                curr_char = char
                draw.text((idx * 8, img_height), char, fill='black', font=font)

        return new_image

    @torch.inference_mode()
    def generate(self, decoder_inputs_embeds_vae, style_text, gen_text, cfg_scale=1.0, max_new_tokens=64):
        """Autoregressively continue a latent sequence and decode it to an image.

        Call this with batch size 1: the special-token decision is taken with
        a scalar comparison.

        Args:
            decoder_inputs_embeds_vae: ``(1, seq, features)`` latent slices
                used as the prefix.
            style_text: One-element sequence with the style transcription.
            gen_text: One-element sequence with the text to generate.
            cfg_scale: Classifier-free guidance scale; 1.0 disables guidance.
            max_new_tokens: Maximum number of decoding steps.

        Returns:
            ``(img, special_sequence)``: the decoded image clamped to
            [-1, 1], and a 1-D float tensor with one entry per latent column
            (3 = supplied prefix, 0 = SOG, 1 = EOG, 2 = generated image slice).
        """
        device = self.T5.device
        prompts = [f"{style}<sog>{gen}" for style, gen in zip(style_text, gen_text)]
        encoded_text = self.tokenizer(prompts, padding=True, return_tensors="pt")
        text_input_ids = encoded_text['input_ids'].to(device)
        text_mask = encoded_text['attention_mask'].to(device)

        sog = self.sog.weight.unsqueeze(0)  # (1, 1, d_model)
        sos = self.sos.weight.unsqueeze(0)  # (1, 1, d_model)

        z_sequence = [decoder_inputs_embeds_vae]
        special_sequence = torch.ones(decoder_inputs_embeds_vae.size(1)) * 3

        decoder_inputs_embeds = torch.cat([sos, self.query_emb(decoder_inputs_embeds_vae)], dim=1)
        if len(style_text[0]) == 0:
            # With an empty style text the SOG token is appended up front.
            decoder_inputs_embeds = torch.cat([decoder_inputs_embeds, sog], dim=1)
            z_sequence.append(self.t5_to_vae(sog))
            special_sequence = torch.cat([special_sequence, torch.zeros(1)])

        use_cfg = cfg_scale != 1.0
        if use_cfg:
            # The text side does not change between steps, so embed it once.
            conditional_text_embeds = self.T5.shared(text_input_ids)
            if self.drop_text:
                unconditional_text_embeds = torch.zeros_like(conditional_text_embeds) + self.uncond_embedding.weight
            else:
                unconditional_text_embeds = conditional_text_embeds

        for _ in range(max_new_tokens):
            if use_cfg:
                if self.drop_img:
                    unconditional_decoder_inputs_embeds = torch.zeros_like(decoder_inputs_embeds) + self.uncond_embedding.weight
                else:
                    unconditional_decoder_inputs_embeds = decoder_inputs_embeds

                output_unconditional = self.T5(
                    inputs_embeds=unconditional_text_embeds,
                    attention_mask=text_mask,
                    decoder_inputs_embeds=unconditional_decoder_inputs_embeds,
                ).logits[:, -1:]
                output_conditional = self.T5(
                    input_ids=text_input_ids,
                    attention_mask=text_mask,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                ).logits[:, -1:]
                output = output_unconditional + (output_conditional - output_unconditional) * cfg_scale
            else:
                output = self.T5(
                    input_ids=text_input_ids,
                    attention_mask=text_mask,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                ).logits[:, -1:]

            # Single-element tensor for batch size 1, so int() is well defined.
            token_kind = int(torch.argmax(self.t5_to_special(output), dim=-1))
            vae_latent = self.t5_to_vae(output)
            z_sequence.append(vae_latent)

            if token_kind == 1:
                # EOG: record it and stop.
                special_sequence = torch.cat([special_sequence, torch.ones(1)])
                break
            if token_kind == 0:
                # SOG: feed the learned SOG embedding as the next input.
                decoder_inputs_embeds = torch.cat([decoder_inputs_embeds, sog], dim=1)
                special_sequence = torch.cat([special_sequence, torch.zeros(1)])
            else:
                # Image slice: feed the predicted latent back in.
                decoder_inputs_embeds = torch.cat([decoder_inputs_embeds, self.query_emb(vae_latent)], dim=1)
                special_sequence = torch.cat([special_sequence, torch.ones(1) * 2])

        z_sequence = torch.cat([el.to(self.vae.device) for el in z_sequence], dim=1)
        img = torch.clamp(self.vae.decode(self.z_rearrange(z_sequence)).sample, -1, 1)
        return img, special_sequence.to(device)

    @torch.no_grad()
    def continue_gen_test(self, gt, batch, max_new_tokens=64, cfg_scale=1.0):
        """Generate a continuation for the first batch sample and build a debug image.

        The top half is the generated image (repeated to 3 channels) with a
        dark vertical line (value -1) at the style/prediction boundary. The
        bottom half colour-codes every column: red = image slice,
        green = SOG, blue = EOG.

        Args:
            gt: Unused; kept for interface compatibility.
            batch: dict with ``'decoder_inputs_embeds'``, ``'style_text'``,
                ``'gen_text'`` and ``'style_img_width'`` (int, or a tensor
                whose first element is used).
            max_new_tokens: Forwarded to ``generate``.
            cfg_scale: Forwarded to ``generate``.

        Returns:
            A ``(3, H, W)`` tensor produced via ``make_grid(normalize=True)``.
        """
        style_img_width = batch['style_img_width']
        if isinstance(style_img_width, torch.Tensor):
            style_img_width = style_img_width[0]
        style_len = int(style_img_width) // self._PIXELS_PER_LATENT

        test_img, special_sequence = self.generate(
            batch['decoder_inputs_embeds'][:1, :style_len],
            batch['style_text'][:1],
            batch['gen_text'][:1],
            cfg_scale=cfg_scale,
            max_new_tokens=max_new_tokens,
        )
        # generate() runs under torch.inference_mode(); clone so the image can
        # be edited in place here (otherwise the separator line below could
        # never be written).
        test_img = test_img.clone()
        special_sequence = special_sequence.repeat_interleave(self._PIXELS_PER_LATENT)

        special_img = torch.zeros_like(test_img).repeat(1, 3, 1, 1)
        special_sequence = special_sequence[:special_img.size(-1)]
        special_img[:, 0, :, special_sequence == 2] = 1  # red: image
        special_img[:, 1, :, special_sequence == 0] = 1  # green: sog
        special_img[:, 2, :, special_sequence == 1] = 1  # blue: eog

        # add a black line between style and pred
        boundary = style_len * self._PIXELS_PER_LATENT
        if boundary < test_img.size(-1):
            test_img[:, :, :, boundary] = -1
        else:
            print("couldn't add black line")

        # add special_img to the bottom of test_img
        test_img = torch.cat([test_img.repeat(1, 3, 1, 1), special_img], dim=-2)

        return torch.cat(list(pad_images([
            make_grid(test_img, nrow=1, normalize=True),
        ])), dim=-2)

    def save_pretrained(self, path):
        """Write one ``state_dict`` file per component into directory ``path``.

        The directory is created if needed. Besides the historical files
        (``T5.pth``, ``VAE.pth``, ``OCR.pth``, ``query_emb.pth``, ``sos.pth``)
        the trained heads and special embeddings are stored as well, so a
        reloaded model can run ``generate``.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        for attr_name, file_name in self._REQUIRED_COMPONENTS + self._OPTIONAL_COMPONENTS:
            torch.save(getattr(self, attr_name).state_dict(), path / file_name)

    def load_pretrained(self, path):
        """Load component weights previously written by ``save_pretrained``.

        The historical files must exist; the additional head/embedding files
        are loaded only when present so older checkpoint directories still
        work. Tensors are deserialised on CPU and copied into the existing
        parameters, so checkpoints saved on a GPU also load on CPU-only hosts.
        """
        path = Path(path)
        for attr_name, file_name in self._REQUIRED_COMPONENTS:
            state = torch.load(path / file_name, map_location='cpu')
            getattr(self, attr_name).load_state_dict(state)
        for attr_name, file_name in self._OPTIONAL_COMPONENTS:
            weights_file = path / file_name
            if weights_file.exists():
                state = torch.load(weights_file, map_location='cpu')
                getattr(self, attr_name).load_state_dict(state)
475
+
476
class DDPCompatibleEmuru(Emuru):
    """``Emuru`` variant whose ``forward`` dispatches on a ``mode`` string.

    Routing every entry point through ``forward`` lets a wrapper that only
    calls ``forward`` (such as DistributedDataParallel) reach the other
    methods. The ``module_*`` methods give direct access when no such
    wrapper is involved.
    """

    def forward(self, batch_data, mode='train'):
        """Unified forward method that handles different modes for DDP compatibility.

        Args:
            batch_data: dict whose required keys depend on ``mode``:
                ``'train'``: ``decoder_inputs_embeds``, ``specials``,
                ``style_text``, ``gen_text``;
                ``'get_model_inputs'``: ``style_img``, ``gen_img``,
                ``style_img_width``, ``gen_img_width``, ``max_img_len``;
                ``'generate'``: ``decoder_inputs_embeds_vae``, ``style_text``,
                ``gen_text`` and optionally ``cfg_scale``, ``max_new_tokens``;
                ``'continue_gen_test'``: ``gt``, ``batch`` and optionally
                ``cfg_scale``, ``max_new_tokens``.
            mode: One of the four mode names above.

        Returns:
            Whatever the selected ``Emuru`` method returns.

        Raises:
            ValueError: If ``mode`` is not recognised.
        """
        if mode == 'train':
            # Training mode - expects the full batch with model inputs already computed
            return super().forward(
                batch_data['decoder_inputs_embeds'],
                batch_data['specials'],
                batch_data['style_text'],
                batch_data['gen_text'],
            )
        if mode == 'get_model_inputs':
            return super().get_model_inputs(
                batch_data['style_img'],
                batch_data['gen_img'],
                batch_data['style_img_width'],
                batch_data['gen_img_width'],
                batch_data['max_img_len'],
            )
        if mode == 'generate':
            return super().generate(
                batch_data['decoder_inputs_embeds_vae'],
                batch_data['style_text'],
                batch_data['gen_text'],
                cfg_scale=batch_data.get('cfg_scale', 1.0),
                max_new_tokens=batch_data.get('max_new_tokens', 64),
            )
        if mode == 'continue_gen_test':
            # Passed by keyword: continue_gen_test takes max_new_tokens
            # before cfg_scale, so positional passing swapped the two values.
            return super().continue_gen_test(
                batch_data['gt'],
                batch_data['batch'],
                max_new_tokens=batch_data.get('max_new_tokens', 64),
                cfg_scale=batch_data.get('cfg_scale', 1.0),
            )
        raise ValueError(f"Unknown mode: {mode}")

    def module_get_model_inputs(self, style_img, gen_img, style_len, gen_len, max_img_len):
        """Direct access method for get_model_inputs when not using DDP forward"""
        return super().get_model_inputs(style_img, gen_img, style_len, gen_len, max_img_len)

    def module_continue_gen_test(self, gt, batch, max_new_tokens=64, cfg_scale=1.0):
        """Direct access method for continue_gen_test when not using DDP forward"""
        return super().continue_gen_test(gt, batch, max_new_tokens=max_new_tokens, cfg_scale=cfg_scale)

    def module_vae_decode(self, latents):
        """Direct access method for VAE decode"""
        return self.vae.decode(latents)

    def get_trainable_parameters(self):
        """Return the parameters with ``requires_grad=True``.

        Useful for creating optimizers with only trainable parameters.
        """
        return [p for p in self.parameters() if p.requires_grad]

    def get_parameter_count(self):
        """Return a dict with total, trainable and frozen parameter counts."""
        total_params = 0
        trainable_params = 0
        for param in self.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count
        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'frozen_parameters': total_params - trainable_params,
        }

    def print_parameter_info(self):
        """Print overall and per-child-module parameter counts to stdout."""
        info = self.get_parameter_count()
        total = info['total_parameters']
        # Guard the ratio so a parameter-less model prints 0% instead of
        # raising ZeroDivisionError.
        ratio = info['trainable_parameters'] / total if total else 0.0
        print("Model Parameter Info:")
        print(f"  Total parameters: {total:,}")
        print(f"  Trainable parameters: {info['trainable_parameters']:,}")
        print(f"  Frozen parameters: {info['frozen_parameters']:,}")
        print(f"  Trainable ratio: {ratio:.2%}")

        # Print per-module info
        print("\nPer-module breakdown:")
        for name, module in self.named_children():
            module_total = sum(p.numel() for p in module.parameters())
            module_trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
            if module_total > 0:
                print(f"  {name}: {module_trainable:,}/{module_total:,} trainable ({module_trainable/module_total:.1%})")