schrum2 committed
Commit 179f3e8 · 1 Parent(s): 0441379

Revert all to basic model files

model_index.json CHANGED
@@ -6,7 +6,7 @@
     "DDPMScheduler"
   ],
   "text_encoder": [
-    "text_model",
+    "models.text_model",
     "TransformerModel"
   ],
   "tokenizer": [
@@ -16,6 +16,5 @@
   "unet": [
     "diffusers",
     "UNet2DConditionModel"
-  ],
-  "pipeline": "TextConditionalDDPMPipeline"
+  ]
 }

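For context: each component entry in model_index.json is a [module_or_library, ClassName] pair that diffusers resolves by import when the pipeline is loaded, so this hunk repoints the custom text encoder at the models package and drops the nonstandard top-level "pipeline" key. A rough sketch of that lookup (an approximation of the loader's behavior, with the subfolder path as an assumed placeholder):

import importlib

# Approximately what the loader does with a ["models.text_model", "TransformerModel"] entry.
module_name, class_name = "models.text_model", "TransformerModel"
cls = getattr(importlib.import_module(module_name), class_name)
text_encoder = cls.from_pretrained("text_encoder")  # assumed repo subfolder
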
models/latent_diffusion_pipeline.py DELETED
@@ -1,99 +0,0 @@
-from diffusers import DDPMPipeline
-import torch
-import torch.nn.functional as F
-from typing import Optional, Union, List, Tuple
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
-import common_settings as common_settings
-import os
-import json
-from general_training_helper import get_scene_from_embeddings
-
-class UnconditionalDDPMPipeline(DDPMPipeline):
-    def __init__(self, unet, scheduler, block_embeddings=None):
-        super().__init__(unet, scheduler)
-
-        self.block_embeddings = block_embeddings
-
-
-    def save_pretrained(self, save_directory):
-        os.makedirs(save_directory, exist_ok=True)
-        super().save_pretrained(save_directory)
-        # Save block_embeddings tensor if it exists
-        if self.block_embeddings is not None:
-            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_path, **kwargs):
-        pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
-        # Load block_embeddings tensor if it exists
-        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
-        if os.path.exists(block_embeds_path):
-            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
-        else:
-            pipeline.block_embeddings = None
-        return pipeline
-
-
-
-    def give_sprite_scaling_factors(self, sprite_scaling_factors):
-        """
-        Set the sprite scaling factors for the pipeline.
-        This is used to apply per-sprite temperature scaling during inference.
-        """
-        self.sprite_scaling_factors = sprite_scaling_factors
-
-    def __call__(
-        self,
-        batch_size: int = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
-        output_type: Optional[str] = "tensor",
-        return_dict: bool = True,
-        height: int = common_settings.MARIO_HEIGHT, width: int = common_settings.MARIO_WIDTH,
-        latents: Optional[torch.FloatTensor] = None,
-        show_progress_bar=True,
-    ) -> Union[ImagePipelineOutput, Tuple]:
-
-        self.unet.eval()
-        with torch.no_grad():
-
-            if latents is not None:
-                image = latents.to(self.device)
-            else:
-                image_shape = (
-                    batch_size,
-                    self.unet.config.in_channels,
-                    height,
-                    width
-                )
-
-                image = torch.randn(image_shape, generator=generator, device=self.device)
-
-            self.scheduler.set_timesteps(num_inference_steps)
-
-            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
-            for t in iterator:
-                # print(image.shape)
-                model_output = self.unet(image, t).sample
-                image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
-
-            # Apply per-sprite temperature scaling if enabled
-            if hasattr(self, "sprite_scaling_factors") and self.sprite_scaling_factors is not None:
-                image = image / self.sprite_scaling_factors.view(1, -1, 1, 1)
-
-
-            if self.block_embeddings is not None:
-                image = get_scene_from_embeddings(image, self.block_embeddings)
-            else:
-                image = F.softmax(image, dim=1)
-            image = image.detach().cpu()
-
-        if not return_dict:
-            return (image,)
-
-        return ImagePipelineOutput(images=image)
-
-    def print_unet_architecture(self):
-        """Prints the architecture of the UNet model."""
-        print(self.unet)

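For reference, a minimal usage sketch of the pipeline deleted above; the checkpoint path is a placeholder:

import torch
from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline

pipe = UnconditionalDDPMPipeline.from_pretrained("path/to/checkpoint")  # placeholder path
out = pipe(batch_size=4, num_inference_steps=50,
           generator=torch.Generator().manual_seed(0))
# out.images: (4, num_tile_channels, MARIO_HEIGHT, MARIO_WIDTH) tile probabilities,
# or a decoded scene when block_embeddings were saved with the checkpoint.
scenes = out.images
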
models/pipeline_loader.py DELETED
@@ -1,32 +0,0 @@
-from models.text_diffusion_pipeline import TextConditionalDDPMPipeline
-from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline
-import os
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from huggingface_hub import snapshot_download
-
-
-def get_pipeline(model_path):
-    # If model_path is a local directory, use the original logic
-    if os.path.isdir(model_path):
-        # Diffusion models
-        if os.path.exists(os.path.join(model_path, "unet")):
-            if os.path.exists(os.path.join(model_path, "text_encoder")):
-                # If it has a text encoder and a unet, it's text-conditional diffusion
-                pipe = TextConditionalDDPMPipeline.from_pretrained(model_path)
-            else:
-                # If it has no text encoder, use the unconditional diffusion model
-                pipe = UnconditionalDDPMPipeline.from_pretrained(model_path)
-    else:
-        # For HF Hub models, download first then load locally
-        print(f"Downloading model {model_path}...")
-        local_path = snapshot_download(repo_id=model_path, cache_dir="./temp_model_cache")
-
-        # Check what components exist
-        has_text_encoder = os.path.exists(os.path.join(local_path, "text_encoder"))
-
-        if has_text_encoder:
-            pipe = TextConditionalDDPMPipeline.from_pretrained(local_path)
-        else:
-            pipe = UnconditionalDDPMPipeline.from_pretrained(local_path)
-
-    return pipe

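For reference, a sketch of how the deleted loader dispatched between the two pipelines (both arguments below are placeholders). Note that the local-directory branch assumes a unet/ subfolder exists; a local path without one leaves pipe unassigned, so the final return would raise UnboundLocalError:

from models.pipeline_loader import get_pipeline

pipe = get_pipeline("./my_local_model")       # placeholder local directory containing unet/
pipe = get_pipeline("schrum2/example-model")  # placeholder Hub repo id, fetched via snapshot_download
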
models/text_diffusion_pipeline.py DELETED
@@ -1,446 +0,0 @@
-import torch
-import torch.nn.functional as F
-from typing import NamedTuple, Optional
-import os
-from diffusers import DDPMPipeline, UNet2DConditionModel, DDPMScheduler, DiffusionPipeline
-import json
-# Running the main at the end of this requires messing with this import
-from text_encoder.text_model import TransformerModel
-from transformers import AutoTokenizer, AutoModel
-import common_settings as common_settings
-import sentence_transformers_helper as st_helper
-import text_encoder.text_model as text_model
-from general_training_helper import get_scene_from_embeddings
-
-class PipelineOutput(NamedTuple):
-    images: torch.Tensor
-
-
-# Create a custom pipeline for text-conditional generation
-class TextConditionalDDPMPipeline(DDPMPipeline):
-    def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
-        # Call parent class init normally
-        super().__init__(unet=unet, scheduler=scheduler)
-
-        self.text_encoder = text_encoder
-        self.tokenizer = tokenizer
-        self.supports_negative_prompt = hasattr(unet, 'negative_prompt_support') and unet.negative_prompt_support
-        self.supports_pretrained_split = supports_pretrained_split
-        self.block_embeddings = block_embeddings
-
-        if self.tokenizer is None and self.text_encoder is not None:
-            # Use the tokenizer from the text encoder if not provided
-            if hasattr(self.text_encoder, 'tokenizer'):
-                self.tokenizer = self.text_encoder.tokenizer
-
-        # Register additional modules if they exist
-        additional_modules = {}
-        if self.text_encoder is not None:
-            additional_modules['text_encoder'] = self.text_encoder
-        if self.tokenizer is not None:
-            additional_modules['tokenizer'] = self.tokenizer
-
-        if additional_modules:
-            self.register_modules(**additional_modules)
-
-    # Override the to() method to ensure text_encoder is moved to the correct device
-    def to(self, device=None, dtype=None):
-        # Call the parent's to() method first
-        pipeline = super().to(device, dtype)
-
-        # Additionally move the text_encoder to the device
-        if self.text_encoder is not None:
-            self.text_encoder.to(device)
-
-        return pipeline
-
-    def save_pretrained(self, save_directory):
-        os.makedirs(save_directory, exist_ok=True)
-        super().save_pretrained(save_directory)  # saves UNet and scheduler
-
-        # Save block_embeddings tensor if it exists
-        if self.block_embeddings is not None:
-            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))
-
-        # Save supports_negative_prompt and supports_pretrained_split flags
-        with open(os.path.join(save_directory, "pipeline_config.json"), "w") as f:
-            json.dump({
-                "supports_negative_prompt": self.supports_negative_prompt,
-                "supports_pretrained_split": self.supports_pretrained_split,
-                "text_encoder_type": type(self.text_encoder).__name__
-            }, f)
-
-
-        # Text encoder/tokenizer saving differs depending on whether we're using a larger pretrained model
-        if isinstance(self.text_encoder, TransformerModel):
-            # Save custom text encoder
-            if self.text_encoder is not None:
-                self.text_encoder.save_pretrained(os.path.join(save_directory, "text_encoder"))
-        else:
-            # Save pretrained tokenizer by name, so we can load from Hugging Face instead of saving a giant local model
-            text_encoder_info = {
-                "text_encoder_name": self.text_encoder.config.name_or_path,
-                "tokenizer_name": self.tokenizer.name_or_path,
-            }
-
-            text_encoder_directory = os.path.join(save_directory, "text_encoder")
-            os.makedirs(text_encoder_directory, exist_ok=True)
-
-            with open(os.path.join(text_encoder_directory, "loading_info.json"), "w") as f:
-                json.dump(text_encoder_info, f)
-
-
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_path, **kwargs):
-        # from diffusers.utils import load_config, load_state_dict
-        # Load model_index.json
-        # model_index = load_config(pretrained_model_path)
-
-        # Load components manually
-        unet_path = os.path.join(pretrained_model_path, "unet")
-        unet = UNet2DConditionModel.from_pretrained(unet_path)
-
-        scheduler_path = os.path.join(pretrained_model_path, "scheduler")
-        # Have heard that DDIMScheduler might be faster for inference, though not necessarily better
-        scheduler = DDPMScheduler.from_pretrained(scheduler_path)
-
-        tokenizer = None
-        text_encoder_path = os.path.join(pretrained_model_path, "text_encoder")
-
-        if os.path.exists(text_encoder_path):
-            # Test for the new saving system, where we save a simple config file
-            if os.path.exists(os.path.join(text_encoder_path, "loading_info.json")):
-                with open(os.path.join(text_encoder_path, "loading_info.json"), "r") as f:
-                    encoder_config = json.load(f)
-
-                text_encoder = AutoModel.from_pretrained(encoder_config['text_encoder_name'], trust_remote_code=True)
-                tokenizer = AutoTokenizer.from_pretrained(encoder_config['tokenizer_name'])
-
-            # Legacy loading system, loads models directly if the whole thing is saved in the directory
-            else:
-                try:
-                    text_encoder = AutoModel.from_pretrained(text_encoder_path, local_files_only=True, trust_remote_code=True)
-                    tokenizer = AutoTokenizer.from_pretrained(text_encoder_path, local_files_only=True)
-                except (ValueError, KeyError):
-                    text_encoder = TransformerModel.from_pretrained(text_encoder_path)
-                    tokenizer = text_encoder.tokenizer
-        else:
-            text_encoder = None
-
-        # Instantiate the pipeline
-        pipeline = cls(
-            unet=unet,
-            scheduler=scheduler,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            **kwargs,
-        )
-
-        # Load block embeddings if present
-        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
-        if os.path.exists(block_embeds_path):
-            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
-        else:
-            pipeline.block_embeddings = None
-
-
-        # Load supports_negative_prompt flag if present
-        config_path = os.path.join(pretrained_model_path, "pipeline_config.json")
-        if os.path.exists(config_path):
-            with open(config_path, "r") as f:
-                config = json.load(f)
-            pipeline.supports_negative_prompt = config.get("supports_negative_prompt", False)
-            pipeline.supports_pretrained_split = config.get("supports_pretrained_split", False)
-        return pipeline
-
-    # --- Handle batching for captions ---
-    def _prepare_text_batch(self, text: Optional[str | list[str]], batch_size: int, name: str) -> Optional[list[str]]:
-        if text is None:
-            return None
-        if isinstance(text, str):
-            return [text] * batch_size
-        if isinstance(text, list):
-            if len(text) == 1:
-                return text * batch_size
-            if len(text) != batch_size:
-                raise ValueError(f"{name} list length {len(text)} does not match batch_size {batch_size}")
-            return text
-        raise ValueError(f"{name} must be a string or list of strings")
-
-    def _prepare_initial_sample(self,
-                                raw_latent_sample: Optional[torch.Tensor],
-                                input_scene: Optional[torch.Tensor],
-                                batch_size: int, height: int, width: int,
-                                generator: Optional[torch.Generator]) -> torch.Tensor:
-        """Prepare the initial sample for diffusion."""
-
-        sample_shape = (batch_size, self.unet.config.in_channels, height, width)
-
-        if raw_latent_sample is not None:
-            if input_scene is not None:
-                raise ValueError("Cannot provide both raw_latent_sample and input_scene")
-            sample = raw_latent_sample.to(self.device)
-            if sample.shape[1] != sample_shape[1]:
-                raise ValueError(f"Wrong number of channels in raw_latent_sample: Expected {self.unet.config.in_channels} but got {sample.shape[1]}")
-            if sample.shape[0] == 1 and batch_size > 1:
-                sample = sample.repeat(batch_size, 1, 1, 1)
-            elif sample.shape[0] != batch_size:
-                raise ValueError(f"raw_latent_sample batch size {sample.shape[0]} does not match batch_size {batch_size}")
-        elif input_scene is not None:
-            # input_scene can be (H, W) or (batch_size, H, W)
-            scene_tensor = torch.tensor(input_scene, dtype=torch.long, device=self.device)
-            if scene_tensor.dim() == 2:
-                # (H, W) -> repeat for batch
-                scene_tensor = scene_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
-            elif scene_tensor.shape[0] == 1 and batch_size > 1:
-                scene_tensor = scene_tensor.repeat(batch_size, 1, 1)
-            elif scene_tensor.shape[0] != batch_size:
-                raise ValueError(f"input_scene batch size {scene_tensor.shape[0]} does not match batch_size {batch_size}")
-            # One-hot encode: (batch, H, W, C)
-            one_hot = F.one_hot(scene_tensor, num_classes=self.unet.config.in_channels).float()
-            # (batch, H, W, C) -> (batch, C, H, W)
-            sample = one_hot.permute(0, 3, 1, 2)
-        else:
-            # Start from random noise
-            sample = torch.randn(sample_shape, generator=generator, device=self.device)
-
-        return sample
-
-    def __call__(
-        self,
-        caption: Optional[str | list[str]] = None,
-        negative_prompt: Optional[str | list[str]] = None,
-        generator: Optional[torch.Generator] = None,
-        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
-        guidance_scale: float = common_settings.GUIDANCE_SCALE,
-        height: int = common_settings.MARIO_HEIGHT,
-        width: int = common_settings.MARIO_WIDTH,
-        raw_latent_sample: Optional[torch.FloatTensor] = None,
-        input_scene: Optional[torch.Tensor] = None,
-        output_type: str = "tensor",
-        batch_size: int = 1,
-        show_progress_bar: bool = True,
-    ) -> PipelineOutput:
-        """Generate a batch of images based on text input using the diffusion model.
-
-        Args:
-            caption: Text description(s) of the desired output. Can be a string or list of strings.
-            negative_prompt: Text description(s) of what should not appear in the output. String or list.
-            generator: Random number generator for reproducibility.
-            num_inference_steps: Number of denoising steps (more = higher quality, slower).
-            guidance_scale: How strongly the generation follows the text prompt (higher = stronger).
-            height: Height of generated image in tiles.
-            width: Width of generated image in tiles.
-            raw_latent_sample: Optional starting point for diffusion instead of random noise.
-                Must have the correct number of channels matching the UNet.
-            input_scene: Optional 2D or 3D int tensor where each value corresponds to a tile type.
-                Will be converted to one-hot encoding as the starting point.
-            output_type: Currently only "tensor" is supported.
-            batch_size: Number of samples to generate in parallel.
-
-        Returns:
-            PipelineOutput containing the generated image tensor (batch_size, ...).
-        """
-
-        # I would like to simplify the code to this, but the AI suggestion didn't work, and
-        # I did not feel good just pasting it all in. Will need to tackle it bit by bit.
-
-        # if caption is not None and self.text_encoder is None:
-        #     raise ValueError("Text encoder required for conditional generation")
-
-        # self.unet.eval()
-        # if self.text_encoder is not None:
-        #     self.text_encoder.to(self.device)
-        #     self.text_encoder.eval()
-        #
-        # with torch.no_grad():
-        #     # Process text inputs
-        #     captions = self.prepare_text_batch(caption, batch_size, "caption")
-        #     negatives = self.prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
-
-        #     # Get embeddings
-        #     text_embeddings = self.prepare_embeddings(captions, negatives, batch_size)
-        #
-        #     # Set up initial latent state
-        #     sample = self.prepare_initial_sample(raw_latent_sample, input_scene,
-        #                                          batch_size, height, width, generator)
-
-        #     # Run diffusion process
-        #     sample = self.run_diffusion(sample, text_embeddings, num_inference_steps,
-        #                                 guidance_scale, generator, show_progress_bar,
-        #                                 has_caption=caption is not None,
-        #                                 has_negative=negative_prompt is not None)
-
-        #     # Format output
-        #     if output_type == "tensor":
-        #         sample = F.softmax(sample, dim=1)
-        #     else:
-        #         raise ValueError(f"Unsupported output type: {output_type}")
-
-        #     return PipelineOutput(images=sample)
-
-        # Validate text encoder if we need it
-        if caption is not None and self.text_encoder is None:
-            raise ValueError("Text encoder is required for conditional generation")
-
-        self.unet.eval()
-        if self.text_encoder is not None:
-            self.text_encoder.to(self.device)
-            self.text_encoder.eval()
-
-        with torch.no_grad():
-            captions = self._prepare_text_batch(caption, batch_size, "caption")
-            negatives = self._prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
-
-            # --- Prepare text embeddings ---
-            if isinstance(self.text_encoder, TransformerModel):
-                text_embeddings = text_model.get_embeddings(batch_size=batch_size,
-                                                            tokenizer=self.text_encoder.tokenizer,
-                                                            text_encoder=self.text_encoder,
-                                                            captions=captions,
-                                                            neg_captions=negatives,
-                                                            device=self.device)
-            else:  # Case for the pre-trained text encoder
-                if self.supports_pretrained_split:  # If we have a split flag incorporated
-                    text_embeddings = st_helper.get_embeddings_split(batch_size=batch_size,
-                                                                     tokenizer=self.tokenizer,
-                                                                     model=self.text_encoder,
-                                                                     captions=captions,
-                                                                     neg_captions=negatives,
-                                                                     device=self.device)
-                else:
-                    text_embeddings = st_helper.get_embeddings(batch_size=batch_size,
-                                                               tokenizer=self.tokenizer,
-                                                               model=self.text_encoder,
-                                                               captions=captions,
-                                                               neg_captions=negatives,
-                                                               device=self.device)
-
-
-            # --- Set up initial latent state ---
-            sample = self._prepare_initial_sample(raw_latent_sample, input_scene,
-                                                  batch_size, height, width, generator)
-
-            # --- Set up diffusion process ---
-            self.scheduler.set_timesteps(num_inference_steps)
-
-            # Denoising loop
-            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
-            for t in iterator:
-                # Handle conditional generation
-                if captions is not None:
-                    if negatives is not None:
-                        # Three copies for negative prompt guidance
-                        model_input = torch.cat([sample, sample, sample], dim=0)
-                    else:
-                        # Two copies for standard classifier-free guidance
-                        model_input = torch.cat([sample, sample], dim=0)
-                else:
-                    model_input = sample
-
-                # Predict noise residual
-                model_kwargs = {"encoder_hidden_states": text_embeddings}
-                noise_pred = self.unet(model_input, t, **model_kwargs).sample
-
-                # Apply guidance
-                if captions is not None:
-                    if negatives is not None:
-                        # Split predictions for negative, unconditional, and text-conditional
-                        noise_pred_neg, noise_pred_uncond, noise_pred_text = noise_pred.chunk(3)
-                        noise_pred_guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-                        noise_pred = noise_pred_guided - guidance_scale * (noise_pred_neg - noise_pred_uncond)
-                    else:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # Compute previous sample: x_{t-1} = scheduler(x_t, noise_pred)
-                sample = self.scheduler.step(noise_pred, t, sample, generator=generator).prev_sample
-
-            # Convert to output format
-            if output_type == "tensor":
-                if self.block_embeddings is not None:
-                    sample = get_scene_from_embeddings(sample, self.block_embeddings)
-                else:
-                    # Apply softmax to get probabilities for each tile type
-                    sample = F.softmax(sample, dim=1)
-                sample = sample.detach().cpu()
-            else:
-                raise ValueError(f"Unsupported output type: {output_type}")
-
-        return PipelineOutput(images=sample)
-
-    def print_unet_architecture(self):
-        """Prints the architecture of the UNet model."""
-        print(self.unet)
-
-    def print_text_encoder_architecture(self):
-        """Prints the architecture of the text encoder model, if it exists."""
-        if self.text_encoder is not None:
-            print(self.text_encoder)
-        else:
-            print("No text encoder is set.")
-
-    def save_unet_architecture_pdf(self, height, width, filename="unet_architecture", batch_size=1, device=None):
-        """
-        Requires a separate install of torchview.
-
-        Saves a visualization of the UNet architecture as a PDF using torchview.
-        Args:
-            height: Height of the dummy input.
-            width: Width of the dummy input.
-            filename: Output PDF filename.
-            batch_size: Batch size for dummy input.
-            device: Device to run the dummy input on (defaults to pipeline device).
-        """
-        from torchview import draw_graph
-        import graphviz
-
-        if device is None:
-            device = self.device if hasattr(self, 'device') else 'cpu'
-        in_channels = self.unet.config.in_channels if hasattr(self.unet, 'config') else 1
-        sample_shape = tuple([batch_size, in_channels, height, width])
-
-        dummy_x = torch.randn(size=sample_shape, device=device)
-        dummy_t = torch.tensor([0] * batch_size, dtype=torch.long, device=device)
-
-        # Prepare dummy text embedding (match what your UNet expects)
-        if hasattr(self.unet, 'config') and hasattr(self.unet.config, 'cross_attention_dim'):
-            cross_attention_dim = self.unet.config.cross_attention_dim
-        else:
-            cross_attention_dim = 128  # fallback
-        encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim, device=device)
-
-        self.unet.eval()
-        inputs = (dummy_x, dummy_t, encoder_hidden_states)
-        # self.unet.down_blocks = self.unet.down_blocks[:2]
-
-        graph = draw_graph(
-            model=self.unet,
-            input_data=inputs,
-            expand_nested=False,
-            # enable_output_shape=True,
-            # roll_out="nested",
-            depth=1
-        )
-        # graph.visual_graph.engine = "neato"
-        graph.visual_graph.attr(  # rankdir="LR",
-            nodesep="0.1",  # decrease space between nodes in the same rank (default ~0.25)
-            ranksep="0.2",  # decrease space between ranks (default ~0.5)
-            concentrate="true"  # merge edges between nodes in the same rank
-        )
-        graph.visual_graph.node_attr.update(
-            shape="rectangle",
-            width="1.5",  # narrow width
-            height="0.5"  # taller height to make vertical rectangles
-            # fixedsize="true"
-        )
-
-        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup removes intermediate files
-        graph.visual_graph.save('unet_architecture.dot')
-
-        # Save the graph to a PDF file
-        print(f"UNet architecture saved to {filename}")

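For reference, a hedged usage sketch of the text-conditional pipeline deleted above; the checkpoint path is a placeholder and the prompts reuse example captions from text_model.py below. With both a caption and a negative prompt, the denoising loop combines the three noise predictions as guided = uncond + g*(text - uncond) - g*(neg - uncond):

import torch
from models.text_diffusion_pipeline import TextConditionalDDPMPipeline

pipe = TextConditionalDDPMPipeline.from_pretrained("path/to/checkpoint")  # placeholder path
out = pipe(
    caption="full floor. two coins. one pipe.",
    negative_prompt="many enemies.",
    batch_size=2,
    guidance_scale=7.5,  # illustrative value; the repo default is common_settings.GUIDANCE_SCALE
    generator=torch.Generator().manual_seed(0),
)
levels = out.images  # (2, num_tile_channels, height, width) tile probabilities
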
text_encoder/model_index.json DELETED
@@ -1,20 +0,0 @@
-{
-  "_class_name": "TextConditionalDDPMPipeline",
-  "_diffusers_version": "0.32.2",
-  "scheduler": [
-    "diffusers",
-    "DDPMScheduler"
-  ],
-  "text_encoder": [
-    "text_model",
-    "TransformerModel"
-  ],
-  "tokenizer": [
-    "Tokenizer"
-  ],
-  "unet": [
-    "diffusers",
-    "UNet2DConditionModel"
-  ],
-  "pipeline": "TextConditionalDDPMPipeline"
-}

text_encoder/text_model.py DELETED
@@ -1,206 +0,0 @@
-import torch
-import torch.nn as nn
-import math
-import os
-import json
-from safetensors.torch import save_file, load_file
-from tokenizer.tokenizer import Tokenizer
-
-def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
-    max_length = text_encoder.max_seq_length
-    empty_ids = encode_token_captions([""] * batch_size, tokenizer, max_length, device=device)
-    embeddings = text_encoder.get_embeddings(empty_ids)
-
-    if captions is not None:
-        caption_ids = encode_token_captions(captions, tokenizer, max_length, device=device)
-        caption_embeddings = text_encoder.get_embeddings(caption_ids)
-        embeddings = torch.cat((embeddings, caption_embeddings), dim=0)
-
-    if neg_captions is not None:
-        neg_ids = encode_token_captions(neg_captions, tokenizer, max_length, device=device)
-        neg_embeddings = text_encoder.get_embeddings(neg_ids)
-        embeddings = torch.cat((neg_embeddings, embeddings), dim=0)
-
-    return embeddings.to(device)
-
-def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
-    caption_ids = []
-    for caption in captions:
-        tokens = tokenizer.encode(caption)
-        caption_tokens = tokenizer.pad_sequence(tokens, max_length)
-        caption_ids.append(torch.tensor(caption_tokens, dtype=torch.long).unsqueeze(0))
-    return torch.cat(caption_ids, dim=0).to(device)
-
-
-# Transformer model for MLM training
-class TransformerModel(nn.Module):
-    def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
-        super().__init__()
-        self.embedding_dim = embedding_dim
-        self.vocab_size = vocab_size
-        self.hidden_dim = hidden_dim
-        self.num_heads = num_heads
-        self.num_layers = num_layers
-        self.max_seq_length = max_seq_length
-
-        self.embedding = nn.Embedding(vocab_size, embedding_dim)
-        self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)
-
-        encoder_layers = nn.TransformerEncoderLayer(
-            d_model=embedding_dim,
-            nhead=num_heads,
-            dim_feedforward=hidden_dim,
-            batch_first=True
-        )
-        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
-        self.fc = nn.Linear(embedding_dim, vocab_size)
-
-        self.tokenizer = tokenizer
-
-    def create_positional_encoding(self, max_seq_length, embedding_dim):
-        # Sinusoidal positional encoding, which creates a unique pattern for each position in the sequence.
-        # The frequencies create unique values; the sin/cos keeps the values bounded.
-        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
-        # Creates a set of divisors that produce different frequencies
-        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
-        pe = torch.zeros(max_seq_length, embedding_dim)
-        # Even dimensions use sin, odd dimensions use cos
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        return pe.unsqueeze(0)
-
-    def get_embeddings(self, x):
-        """Returns the latent embedding vectors for a batch of token ids."""
-        # Ensure positional encoding is on the same device as input
-        pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
-        # Embed input and add positional encoding
-        embedded = self.embedding(x) + pe
-        return self.transformer(embedded)
-
-    def forward(self, x):
-        """Returns per-position logits over the vocabulary."""
-        transformer_out = self.get_embeddings(x)
-        # Project to vocabulary size
-        return self.fc(transformer_out)
-
-    def save_pretrained(self, save_directory):
-        os.makedirs(save_directory, exist_ok=True)
-
-        config = {
-            "vocab_size": self.vocab_size,
-            "embedding_dim": self.embedding_dim,
-            "hidden_dim": self.hidden_dim,
-            "num_heads": self.num_heads,
-            "num_layers": self.num_layers,
-            "max_seq_length": self.max_seq_length,
-        }
-        with open(os.path.join(save_directory, "config.json"), "w") as f:
-            json.dump(config, f)
-
-        # Save model weights
-        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
-
-        # Save tokenizer if present
-        if self.tokenizer is not None:
-            self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))
-
-    @classmethod
-    def from_pretrained(cls, load_directory):
-        with open(os.path.join(load_directory, "config.json")) as f:
-            config = json.load(f)
-
-        model = cls(**config)
-
-        # Load weights
-        state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
-        model.load_state_dict(state_dict)
-
-        # Load tokenizer if available
-        tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
-        if os.path.exists(tokenizer_path):
-            tokenizer = Tokenizer()
-            tokenizer.load(tokenizer_path)
-            model.tokenizer = tokenizer
-
-        return model
-
-    def print_architecture(self, inputs=None):
-        """Renders this model's architecture to a PDF using torchview."""
-        from torchview import draw_graph
-
-        graph = draw_graph(
-            model=self,
-            input_data=inputs,
-            expand_nested=False,
-            depth=1
-        )
-
-        # Save plot
-        filename = 'mlm_architecture'
-        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup removes intermediate files
-
-    def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
-        """Save a visualization of the model architecture as a PDF using torchview."""
-        try:
-            from torchview import draw_graph
-        except ImportError:
-            raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")
-        # Create a dummy input of the correct type for the model
-        captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
-        tensor = encode_token_captions(captions, self.tokenizer, self.max_seq_length, device=next(self.parameters()).device)
-        input_length = tensor.size(1) if tensor.dim() > 1 else self.max_seq_length
-
-        num_tokens_list = [len(self.tokenizer.encode(c)) for c in captions]
-        input_length = max(num_tokens_list) if num_tokens_list else input_length
-        dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=next(self.parameters()).device)
-
-        # Draw the graph and save as PNG
-        graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf', ''), directory=".", depth=2)
-        png_file = filename.replace('.pdf', '.png')
-        # Convert PNG to PDF
-        if os.path.exists(png_file):
-            try:
-                from PIL import Image
-                im = Image.open(png_file)
-                im.save(filename, "PDF", resolution=100.0)
-                print(f"Saved architecture PDF to {filename}")
-                # Optionally, remove the PNG file
-                os.remove(png_file)
-            except ImportError:
-                print(f"PIL not installed. Architecture saved as PNG: {png_file}")
-            except Exception as e:
-                print(f"Could not convert PNG to PDF: {e}")
-        else:
-            print(f"Could not find PNG file to convert: {png_file}")

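For reference, a sketch of building the deleted text encoder and embedding a caption; the tokenizer pickle name is the CLI default from tokenizer.py below, and the model sizes are illustrative rather than this repo's training configuration:

import torch
from text_encoder.text_model import TransformerModel, encode_token_captions
from tokenizer.tokenizer import Tokenizer

tokenizer = Tokenizer()
tokenizer.load("Mario_Tokenizer.pkl")  # default pickle name from tokenizer.py

model = TransformerModel(
    vocab_size=tokenizer.get_vocab_size(),
    embedding_dim=128,  # illustrative size
    hidden_dim=256,     # illustrative size
    tokenizer=tokenizer,
)
ids = encode_token_captions(["full floor. two coins. one pipe."],
                            tokenizer, model.max_seq_length)
embeddings = model.get_embeddings(ids)  # (1, max_seq_length, embedding_dim)
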
tokenizer/tokenizer.py DELETED
@@ -1,147 +0,0 @@
-import json
-import re
-from collections import Counter
-import pickle
-import argparse
-
-class Tokenizer:
-    def __init__(self):
-        self.special_tokens = ["[PAD]", "[MASK]"]
-        self.vocab = {}
-        self.token_to_id = {}
-        self.id_to_token = {}
-
-    def tokenize(self, text):
-        # Match words, numbers, periods, and commas as separate tokens
-        tokens = re.findall(r'\w+|[.,]|\[mask\]|\[pad\]', text.lower())
-        # Restore MASK and PAD to all caps
-        modified_list = []
-        for s in tokens:
-            modified_s = s.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]")
-            modified_list.append(modified_s)
-        return modified_list
-
-    def pad_sequence(self, tokens, length):
-        """Pads a token id sequence to the given length with the '[PAD]' token id."""
-        if len(tokens) > length:
-            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")
-
-        pad_token = self.token_to_id["[PAD]"]
-        return tokens + [pad_token] * (length - len(tokens))
-
-    def build_vocab(self, dataset_path, min_freq=1):
-        token_counter = Counter()
-
-        with open(dataset_path, 'r') as f:
-            data = json.load(f)
-            for entry in data:
-                caption = entry['caption']
-                tokens = self.tokenize(caption)
-                token_counter.update(tokens)
-
-        # Keep tokens that meet the min frequency
-        tokens = [tok for tok, count in token_counter.items() if count >= min_freq]
-
-        # Ensure special tokens are always included
-        all_tokens = self.special_tokens + sorted(tokens)
-
-        # Build vocab dictionaries
-        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
-        self.token_to_id = self.vocab
-        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
-
-        print(f"Vocabulary size: {len(self.vocab)}")
-
-    def encode(self, text):
-        tokens = self.tokenize(text)
-        encoded = []
-        for tok in tokens:
-            if tok not in self.token_to_id:
-                raise ValueError(f"Unknown token encountered: {tok} in {text}")
-            encoded.append(self.token_to_id[tok])
-        return encoded
-
-    def encode_batch(self, texts, pad_to_length=None):
-        """
-        Encode a batch of texts into token IDs with padding to ensure uniform length.
-
-        Args:
-            texts (list): A list of strings to encode
-            pad_to_length (int, optional): Length to pad all sequences to. If None,
-                will pad to the length of the longest sequence.
-
-        Returns:
-            list: A list of lists, where each inner list contains the token IDs for a text
-        """
-        # Get the padding token ID
-        pad_token = self.token_to_id["[PAD]"]
-
-        # First encode all texts
-        encoded_texts = []
-        for text in texts:
-            try:
-                encoded = self.encode(text)
-                encoded_texts.append(encoded)
-            except ValueError as e:
-                raise ValueError(f"Error encoding text: {text}. {str(e)}")
-
-        # Determine padding length
-        if pad_to_length is None:
-            pad_to_length = max(len(seq) for seq in encoded_texts)
-
-        # Pad sequences to uniform length
-        padded_texts = []
-        for seq in encoded_texts:
-            if len(seq) > pad_to_length:
-                # Truncate if too long
-                padded_texts.append(seq[:pad_to_length])
-            else:
-                # Pad if too short
-                padding = [pad_token] * (pad_to_length - len(seq))
-                padded_texts.append(seq + padding)
-
-        return padded_texts
-
-    def decode(self, token_ids):
-        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)
-
-    def save(self, path):
-        with open(path, 'wb') as f:
-            pickle.dump({'vocab': self.vocab}, f)
-
-    def load(self, path):
-        with open(path, 'rb') as f:
-            data = pickle.load(f)
-        self.vocab = data['vocab']
-        self.token_to_id = self.vocab
-        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
-
-    def get_vocab(self):
-        return sorted(self.vocab.keys())
-
-    def get_vocab_size(self):
-        return len(self.vocab)
-
-if __name__ == "__main__":
-    tokenizer = Tokenizer()
-
-    parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
-    parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
-    parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
-    parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")
-
-    args = parser.parse_args()
-
-    if args.action == "save":
-        if not args.json_file:
-            raise ValueError("The --json_file argument is required for the 'save' action.")
-        tokenizer.build_vocab(args.json_file)
-        tokenizer.save(args.pkl_file)
-    elif args.action == "load":
-        tokenizer.load(args.pkl_file)
-
-    # Example usage
-    # print(tokenizer.encode("floor with one gap. one enemy."))
-    # print(tokenizer.get_vocab())
-    # for id, token in tokenizer.id_to_token.items():
-    #     print(id, ":", token)

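For reference, a round-trip sketch of the deleted tokenizer; the dataset filename is the CLI default above:

from tokenizer.tokenizer import Tokenizer

tokenizer = Tokenizer()
tokenizer.build_vocab("Mario_LevelsAndCaptions.json")  # default dataset name from the CLI above
ids = tokenizer.encode("floor with one gap. one enemy.")
padded = tokenizer.pad_sequence(ids, 20)  # pad with [PAD] ids up to length 20
print(tokenizer.decode(padded))  # tokens rejoined with spaces; periods come back as separate tokens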