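"""Text-conditional DDPM pipeline built on diffusers.

Defines TextConditionalDDPMPipeline, a DDPMPipeline subclass that adds text
conditioning with optional negative prompts and classifier-free guidance,
save/load support for both a custom TransformerModel text encoder and
pretrained Hugging Face encoders, and utilities for inspecting the UNet.
"""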
import os
import json
from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
from diffusers import DDPMPipeline, UNet2DConditionModel, DDPMScheduler
from transformers import AutoTokenizer, AutoModel

import common_settings
import sentence_transformers_helper as st_helper
import text_model
# NOTE: running the __main__ demo at the end of this file may require adjusting this import
from text_model import TransformerModel
from general_training_helper import get_scene_from_embeddings

class PipelineOutput(NamedTuple):
    images: torch.Tensor
    


# Create a custom pipeline for text-conditional generation
class TextConditionalDDPMPipeline(DDPMPipeline):
    def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
        super().__init__(unet=unet, scheduler=scheduler)
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.supports_negative_prompt = hasattr(unet, 'negative_prompt_support') and unet.negative_prompt_support
        self.supports_pretrained_split = supports_pretrained_split
        self.block_embeddings = block_embeddings

        if self.tokenizer is None and self.text_encoder is not None:
            # Use the tokenizer from the text encoder if not provided
            self.tokenizer = self.text_encoder.tokenizer
        
        # Register the text_encoder so that .to(), .cpu(), .cuda(), etc. work correctly
        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
        )
    
    # Override the to() method to ensure text_encoder is moved to the correct device
    def to(self, device=None, dtype=None):
        # Call the parent's to() method first
        pipeline = super().to(device, dtype)
        
        # Additionally move the text_encoder to the device
        if self.text_encoder is not None:
            self.text_encoder.to(device)
        
        return pipeline

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory)  # saves UNet and scheduler

        # Save block_embeddings tensor if it exists
        if self.block_embeddings is not None:
            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))

        # Save supports_negative_prompt and supports_pretrained_split flags
        with open(os.path.join(save_directory, "pipeline_config.json"), "w") as f:
            json.dump({
                "supports_negative_prompt": self.supports_negative_prompt,
                "supports_pretrained_split": self.supports_pretrained_split,
                "text_encoder_type": type(self.text_encoder).__name__   
            }, f)


        # Text encoder/tokenizer saving differs depending on whether we are using
        # our own small model or a larger pretrained one
        if isinstance(self.text_encoder, TransformerModel):
            # Save the custom text encoder in full
            self.text_encoder.save_pretrained(os.path.join(save_directory, "text_encoder"))
        else:
            # Save the pretrained tokenizer/encoder by name, so we can reload from
            # Hugging Face instead of saving a giant local copy of the model
            text_encoder_info = {
                "text_encoder_name": self.text_encoder.config.name_or_path,
                "tokenizer_name": self.tokenizer.name_or_path,
            }

            text_encoder_directory = os.path.join(save_directory, "text_encoder")
            os.makedirs(text_encoder_directory, exist_ok=True)

            with open(os.path.join(text_encoder_directory, "loading_info.json"), "w") as f:
                json.dump(text_encoder_info, f)
            


    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):

        # Load components manually
        unet_path = os.path.join(pretrained_model_path, "unet")
        unet = UNet2DConditionModel.from_pretrained(unet_path)

        scheduler_path = os.path.join(pretrained_model_path, "scheduler")
        # Have heard that DDIMScheduler might be faster for inference, though not necessarily better
        scheduler = DDPMScheduler.from_pretrained(scheduler_path)

        tokenizer = None
        text_encoder_path = os.path.join(pretrained_model_path, "text_encoder")
        
        if os.path.exists(text_encoder_path):
            # New saving system: a small config file records which Hugging Face models to load
            if os.path.exists(os.path.join(text_encoder_path, "loading_info.json")):
                with open(os.path.join(text_encoder_path, "loading_info.json"), "r") as f:
                    encoder_config = json.load(f)

                text_encoder = AutoModel.from_pretrained(encoder_config['text_encoder_name'], trust_remote_code=True)
                tokenizer = AutoTokenizer.from_pretrained(encoder_config['tokenizer_name'])
            
            # Legacy saving system: load the models directly if the whole thing was saved in the directory
            else:
                try:
                    text_encoder = AutoModel.from_pretrained(text_encoder_path, local_files_only=True, trust_remote_code=True)
                    tokenizer = AutoTokenizer.from_pretrained(text_encoder_path, local_files_only=True)
                except (ValueError, KeyError):
                    text_encoder = TransformerModel.from_pretrained(text_encoder_path)
                    tokenizer = text_encoder.tokenizer
        else:
            text_encoder = None

        # Instantiate the pipeline
        pipeline = cls(
            unet=unet,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            **kwargs,
        )

        # Load block embeddings if present
        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(block_embeds_path):
            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
        else:
            pipeline.block_embeddings = None
        

        # Load the supports_negative_prompt and supports_pretrained_split flags if present
        config_path = os.path.join(pretrained_model_path, "pipeline_config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            pipeline.supports_negative_prompt = config.get("supports_negative_prompt", False)
            pipeline.supports_pretrained_split = config.get("supports_pretrained_split", False)
        return pipeline

    # --- Handle batching for captions ---
    def _prepare_text_batch(self, text: Optional[str | list[str]], batch_size: int, name: str) -> Optional[list[str]]:
        if text is None:
            return None
        if isinstance(text, str):
            return [text] * batch_size
        if isinstance(text, list):
            if len(text) == 1:
                return text * batch_size
            if len(text) != batch_size:
                raise ValueError(f"{name} list length {len(text)} does not match batch_size {batch_size}")
            return text
        raise ValueError(f"{name} must be a string or list of strings")

    def _prepare_initial_sample(self,
                                raw_latent_sample: Optional[torch.Tensor],
                                input_scene: Optional[torch.Tensor],
                                batch_size: int, height: int, width: int,
                                generator: Optional[torch.Generator]) -> torch.Tensor:
        """Prepare the initial sample for diffusion."""
 
        sample_shape = (batch_size, self.unet.config.in_channels, height, width)

        if raw_latent_sample is not None:
            if input_scene is not None:
                raise ValueError("Cannot provide both raw_latent_sample and input_scene")
            sample = raw_latent_sample.to(self.device)
            if sample.shape[1] != sample_shape[1]:
                raise ValueError(f"Wrong number of channels in raw_latent_sample: Expected {self.unet.config.in_channels} but got {sample.shape[1]}")
            if sample.shape[0] == 1 and batch_size > 1:
                sample = sample.repeat(batch_size, 1, 1, 1)
            elif sample.shape[0] != batch_size:
                raise ValueError(f"raw_latent_sample batch size {sample.shape[0]} does not match batch_size {batch_size}")
        elif input_scene is not None:
            # input_scene can be (H, W) or (batch_size, H, W)
            scene_tensor = torch.as_tensor(input_scene, dtype=torch.long, device=self.device)
            if scene_tensor.dim() == 2:
                # (H, W) -> repeat for batch
                scene_tensor = scene_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
            elif scene_tensor.shape[0] == 1 and batch_size > 1:
                scene_tensor = scene_tensor.repeat(batch_size, 1, 1)
            elif scene_tensor.shape[0] != batch_size:
                raise ValueError(f"input_scene batch size {scene_tensor.shape[0]} does not match batch_size {batch_size}")
            # One-hot encode: (batch, H, W, C)
            one_hot = F.one_hot(scene_tensor, num_classes=self.unet.config.in_channels).float()
            # (batch, H, W, C) -> (batch, C, H, W)
            sample = one_hot.permute(0, 3, 1, 2)
        else:
            # Start from random noise
            sample = torch.randn(sample_shape, generator=generator, device=self.device)

        return sample

    def __call__(
        self,
        caption: Optional[str | list[str]] = None,
        negative_prompt: Optional[str | list[str]] = None,
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        guidance_scale: float = common_settings.GUIDANCE_SCALE,
        height: int = common_settings.MARIO_HEIGHT,
        width: int = common_settings.MARIO_WIDTH,
        raw_latent_sample: Optional[torch.FloatTensor] = None,
        input_scene: Optional[torch.Tensor] = None,
        output_type: str = "tensor",
        batch_size: int = 1,
        show_progress_bar: bool = True,
    ) -> PipelineOutput:
        """Generate a batch of images based on text input using the diffusion model.



        Args:

            caption: Text description(s) of the desired output. Can be a string or list of strings.

            negative_prompt: Text description(s) of what should not appear in the output. String or list.

            generator: Random number generator for reproducibility.

            num_inference_steps: Number of denoising steps (more = higher quality, slower).

            guidance_scale: How strongly the generation follows the text prompt (higher = stronger).

            height: Height of generated image in tiles.

            width: Width of generated image in tiles.

            raw_latent_sample: Optional starting point for diffusion instead of random noise.

                Must have correct number of channels matching the UNet.

            input_scene: Optional 2D or 3D int tensor where each value corresponds to a tile type.

                Will be converted to one-hot encoding as starting point.

            output_type: Currently only "tensor" is supported.

            batch_size: Number of samples to generate in parallel.



        Returns:

            PipelineOutput containing the generated image tensor (batch_size, ...).

        """

        # TODO: I would like to simplify this method to the outline below, but the
        # AI-suggested refactor did not work as-is, and I did not feel good just
        # pasting it all in. Will need to tackle it bit by bit.

        #        if caption is not None and self.text_encoder is None:
        #            raise ValueError("Text encoder required for conditional generation")
    
        #        self.unet.eval()
        #        if self.text_encoder is not None:
        #            self.text_encoder.to(self.device)
        #            self.text_encoder.eval()
        #
        #        with torch.no_grad():
        #            # Process text inputs
        #            captions = self.prepare_text_batch(caption, batch_size, "caption")
        #            negatives = self.prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
         
        #            # Get embeddings
        #            text_embeddings = self.prepare_embeddings(captions, negatives, batch_size)
        #            
        #            # Set up initial latent state
        #            sample = self.prepare_initial_sample(raw_latent_sample, input_scene, 
        #                                              batch_size, height, width, generator)
           
        #            # Run diffusion process
        #            sample = self.run_diffusion(sample, text_embeddings, num_inference_steps,
        #                                      guidance_scale, generator, show_progress_bar,
        #                                      has_caption=caption is not None,
        #                                      has_negative=negative_prompt is not None)
         
        #            # Format output
        #            if output_type == "tensor":
        #                sample = F.softmax(sample, dim=1)
        #            else:
        #                raise ValueError(f"Unsupported output type: {output_type}")

        #        return PipelineOutput(images=sample)

        # Validate text encoder if we need it
        if caption is not None and self.text_encoder is None:
            raise ValueError("Text encoder is required for conditional generation")

        self.unet.eval()
        if self.text_encoder is not None:
            self.text_encoder.to(self.device)
            self.text_encoder.eval()

        with torch.no_grad():
            captions = self._prepare_text_batch(caption, batch_size, "caption")
            negatives = self._prepare_text_batch(negative_prompt, batch_size, "negative_prompt")

            # --- Prepare text embeddings ---
            if isinstance(self.text_encoder, TransformerModel):
                text_embeddings = text_model.get_embeddings(batch_size=batch_size,
                                                            tokenizer=self.text_encoder.tokenizer,
                                                            text_encoder=self.text_encoder,
                                                            captions=captions,
                                                            neg_captions=negatives,
                                                            device=self.device)
            else:  # Case for the pretrained text encoder
                if self.supports_pretrained_split:  # If we have a split flag incorporated
                    text_embeddings = st_helper.get_embeddings_split(batch_size=batch_size,
                                                                     tokenizer=self.tokenizer,
                                                                     model=self.text_encoder,
                                                                     captions=captions,
                                                                     neg_captions=negatives,
                                                                     device=self.device)
                else:
                    text_embeddings = st_helper.get_embeddings(batch_size=batch_size,
                                                               tokenizer=self.tokenizer,
                                                               model=self.text_encoder,
                                                               captions=captions,
                                                               neg_captions=negatives,
                                                               device=self.device)

            # --- Set up initial latent state ---
            sample = self._prepare_initial_sample(raw_latent_sample, input_scene, 
                                                 batch_size, height, width, generator)

            # --- Set up diffusion process ---
            self.scheduler.set_timesteps(num_inference_steps)

            # Denoising loop
            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
            for t in iterator:
                # Handle conditional generation
                if captions is not None:
                    if negatives is not None:
                        # Three copies for negative prompt guidance
                        model_input = torch.cat([sample, sample, sample], dim=0)
                    else:
                        # Two copies for standard classifier-free guidance
                        model_input = torch.cat([sample, sample], dim=0)
                else:
                    model_input = sample

                # Predict noise residual
                model_kwargs = {"encoder_hidden_states": text_embeddings}
                noise_pred = self.unet(model_input, t, **model_kwargs).sample

                # Apply guidance
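                # Classifier-free guidance: start from the unconditional prediction
                # and push toward the text-conditional one; with a negative prompt,
                # additionally push away from the negative-conditional prediction:
                #   eps = eps_uncond + s*(eps_text - eps_uncond) - s*(eps_neg - eps_uncond)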
                if captions is not None:
                    if negatives is not None:
                        # Split predictions for negative, unconditional, and text-conditional
                        noise_pred_neg, noise_pred_uncond, noise_pred_text = noise_pred.chunk(3)
                        noise_pred_guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        noise_pred = noise_pred_guided - guidance_scale * (noise_pred_neg - noise_pred_uncond)
                    else:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Compute previous sample: x_{t-1} = scheduler(x_t, noise_pred)
                sample = self.scheduler.step(noise_pred, t, sample, generator=generator).prev_sample

            # Convert to output format
            if output_type == "tensor":
                if self.block_embeddings is not None:
                    sample = get_scene_from_embeddings(sample, self.block_embeddings)
                else:
                    # Apply softmax to get probabilities for each tile type
                    sample = F.softmax(sample, dim=1)
                    sample = sample.detach().cpu() 
            else:
                raise ValueError(f"Unsupported output type: {output_type}")

        return PipelineOutput(images=sample)

    def print_unet_architecture(self):
        """Prints the architecture of the UNet model."""
        print(self.unet)

    def print_text_encoder_architecture(self):
        """Prints the architecture of the text encoder model, if it exists."""
        if self.text_encoder is not None:
            print(self.text_encoder)
        else:
            print("No text encoder is set.")

    def save_unet_architecture_pdf(self, height, width, filename="unet_architecture", batch_size=1, device=None):
        """

        Have to separately install torchview for this to work



        Saves a visualization of the UNet architecture as a PDF using torchview.

        Args:

            height: Height of the dummy input.

            width: Width of the dummy input.

            filename: Output PDF filename.

            batch_size: Batch size for dummy input.

            device: Device to run the dummy input on (defaults to pipeline device).

        """
        from torchview import draw_graph
        import graphviz
        
        if device is None:
            device = self.device if hasattr(self, 'device') else 'cpu'
        in_channels = self.unet.config.in_channels if hasattr(self.unet, 'config') else 1
        sample_shape = (batch_size, in_channels, height, width)

        dummy_x = torch.randn(size=sample_shape, device=device)
        dummy_t = torch.tensor([0] * batch_size, dtype=torch.long, device=device)

        # Prepare dummy text embedding (match what your UNet expects)
        if hasattr(self.unet, 'config') and hasattr(self.unet.config, 'cross_attention_dim'):
            cross_attention_dim = self.unet.config.cross_attention_dim
        else:
            cross_attention_dim = 128  # fallback
        encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim, device=device)

        self.unet.eval()
        inputs = (dummy_x, dummy_t, encoder_hidden_states)
        #self.unet.down_blocks = self.unet.down_blocks[:2]

        graph = draw_graph(
            model=self.unet,
            input_data=inputs,
            expand_nested=False,
            #enable_output_shape=True,   
            #roll_out="nested",
            depth=1
        )
        #graph.visual_graph.engine = "neato"
        graph.visual_graph.attr(#rankdir="LR",
                                nodesep="0.1",      # decrease space between nodes in the same rank (default ~0.25)
                                ranksep="0.2",       # decrease space between ranks (default ~0.5)
                                concentrate="true"  # merge edges between nodes in the same rank
                            )
        graph.visual_graph.node_attr.update(
            shape="rectangle",
            width="1.5",   # narrow width
            height="0.5"  # taller height to make vertical rectangles
            #fixedsize="true"
        )
        
        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup=True would delete the intermediate DOT file
        graph.visual_graph.save('unet_architecture.dot')

        print(f"UNet architecture saved to {filename}")