schrum2 commited on
Commit
4c0a730
·
verified ·
1 Parent(s): e7ac06f

Deleting directories, moving files into root

Browse files
models/general_training_helper.py DELETED
@@ -1,172 +0,0 @@
1
- from torch.utils.data import DataLoader
2
- from level_dataset import LevelDataset
3
- import random
4
- from util.plotter import Plotter
5
- from datetime import datetime
6
- import os
7
- import threading
8
- import json
9
- import torch.nn.functional as F
10
- import torch
11
-
12
-
13
-
14
-
15
- def create_dataloaders(json_path, val_json, tokenizer, data_mode, augment, num_tiles,
16
- negative_prompt_training, block_embeddings, batch_size):
17
- # Initialize dataset
18
- train_dataset = LevelDataset(
19
- json_path=json_path,
20
- tokenizer=tokenizer,
21
- shuffle=True,
22
- mode=data_mode,
23
- augment=augment,
24
- num_tiles=num_tiles,
25
- negative_captions=negative_prompt_training,
26
- block_embeddings=block_embeddings
27
- )
28
- val_dataset = None
29
- if val_json is not None:
30
- val_dataset = LevelDataset(
31
- json_path=val_json,
32
- tokenizer=tokenizer,
33
- shuffle=False,
34
- mode=data_mode,
35
- augment=False,
36
- num_tiles=num_tiles,
37
- negative_captions=negative_prompt_training,
38
- block_embeddings=block_embeddings
39
- )
40
-
41
- # Create dataloader
42
- train_dataloader = DataLoader(
43
- train_dataset,
44
- batch_size=batch_size,
45
- shuffle=True,
46
- num_workers=4,
47
- drop_last=True,
48
- persistent_workers=True
49
- )
50
-
51
- val_dataloader = None
52
- if val_dataset is not None:
53
- val_dataloader = DataLoader(
54
- val_dataset,
55
- batch_size=batch_size,
56
- shuffle=False,
57
- num_workers=4,
58
- drop_last=False,
59
- persistent_workers=True
60
- )
61
-
62
- return train_dataloader, val_dataloader
63
-
64
-
65
- def get_random_training_samples(train_dataloader, negative_prompt_training, output_dir = None):
66
- train_dataset = train_dataloader.dataset
67
- # Sample four random captions from the dataset
68
- sample_indices = [random.randint(0, len(train_dataset) - 1) for _ in range(4)]
69
-
70
- sample_captions = [train_dataset[i][1] for i in sample_indices]
71
- print("Sample captions:")
72
- for caption in sample_captions:
73
- print(caption)
74
-
75
- sample_negative_captions = ""
76
- if negative_prompt_training:
77
- sample_negative_captions = [train_dataset[i][2] for i in sample_indices]
78
- print("Sample negative captions:")
79
- for caption in sample_negative_captions:
80
- print(f" NEG: {caption}")
81
-
82
- #Write captions to a file
83
- if output_dir is not None:
84
- os.makedirs(output_dir, exist_ok=True)
85
- out_path = os.path.join(output_dir, "sample_captions.txt")
86
- with open(out_path, "w", encoding="utf-8") as f:
87
- f.write("Sample captions:\n")
88
- for caption in sample_captions:
89
- f.write(str(caption) + "\n")
90
- if negative_prompt_training:
91
- f.write("\nSample negative captions:\n")
92
- for caption in sample_negative_captions:
93
- f.write(str(caption) + "\n")
94
- print(f"Sample captions written to {out_path}")
95
-
96
-
97
- return sample_captions, sample_negative_captions
98
-
99
-
100
- def start_plotter(log_file, output_dir, left_key, right_key, left_label, right_label, png_name):
101
- formatted_date = datetime.now().strftime(r'%Y%m%d-%H%M%S')
102
-
103
- plotter = Plotter(log_file, update_interval=5.0, left_key=left_key, right_key=right_key,
104
- left_label=left_label, right_label=right_label, output_png=f'{png_name}_{formatted_date}.png')
105
- plot_thread = threading.Thread(target=plotter.start_plotting)
106
- plot_thread.daemon = True
107
- plot_thread.start()
108
- print(f"{png_name} plotting enabled. Progress will be saved to {os.path.join(output_dir, f'{png_name}_{formatted_date}.png')}")
109
- return plotter, plot_thread
110
-
111
-
112
- def kill_plotter(plotter, plot_thread):
113
- if plot_thread and plot_thread.is_alive():
114
- plotter.stop_plotting()
115
- plot_thread.join(timeout=5.0)
116
- if plot_thread.is_alive():
117
- print("Warning: Plot thread did not terminate properly")
118
-
119
-
120
- def load_config_from_json(config_path):
121
- """Load hyperparameters from a JSON config file."""
122
- try:
123
- with open(config_path, 'r') as f:
124
- config = json.load(f)
125
- print(f"Configuration loaded from {config_path}")
126
-
127
- # Print the loaded config for verification
128
- print("Loaded hyperparameters:")
129
- for key, value in config.items():
130
- print(f" {key}: {value}")
131
-
132
- return config
133
- except (json.JSONDecodeError, FileNotFoundError) as e:
134
- print(f"Error loading config file: {e}")
135
- raise e
136
-
137
-
138
- def update_args_from_config(args, config):
139
- """Update argparse namespace with values from config."""
140
- # Convert config dict to argparse namespace
141
- for key, value in config.items():
142
- if hasattr(args, key):
143
- setattr(args, key, value)
144
- return args
145
-
146
-
147
- def get_scene_from_embeddings(image, block_embeddings):
148
- """Code copied over from level_dataset, should give limited support for block embeddings"""
149
- # Reshape sample to [batch_size * height * width, embedding_dim]
150
- batch_size, embedding_dim, height, width = image.shape
151
-
152
- flat_samples = image.permute(0, 2, 3, 1).reshape(-1, embedding_dim)
153
-
154
- # Normalize vectors for cosine similarity
155
- flat_samples = F.normalize(flat_samples, p=2, dim=1).cpu()
156
- block_embeddings = F.normalize(block_embeddings, p=2, dim=1)
157
-
158
- # Calculate cosine similarity between each position and all tile embeddings
159
- similarities = torch.matmul(flat_samples, block_embeddings.t())
160
-
161
- # Get indices of most similar tiles
162
- indices = torch.softmax(similarities, dim=1)
163
-
164
-
165
- # Reshape back to [batch_size, height, width]
166
- indices = indices.reshape(batch_size, height, width, 13)
167
- indices = indices.permute(0, 3, 1, 2)
168
-
169
- image=indices.detach().cpu()
170
- return image
171
-
172
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/latent_diffusion_pipeline.py DELETED
@@ -1,99 +0,0 @@
1
- from diffusers import DDPMPipeline
2
- import torch
3
- import torch.nn.functional as F
4
- from typing import Optional, Union, List, Tuple
5
- from diffusers.utils.torch_utils import randn_tensor
6
- from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
7
- import util.common_settings as common_settings
8
- import os
9
- import json
10
- from models.general_training_helper import get_scene_from_embeddings
11
-
12
- class UnconditionalDDPMPipeline(DDPMPipeline):
13
- def __init__(self, unet, scheduler, block_embeddings=None):
14
- super().__init__(unet, scheduler)
15
-
16
- self.block_embeddings = block_embeddings
17
-
18
-
19
- def save_pretrained(self, save_directory):
20
- os.makedirs(save_directory, exist_ok=True)
21
- super().save_pretrained(save_directory)
22
- # Save block_embeddings tensor if it exists
23
- if self.block_embeddings is not None:
24
- torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))
25
-
26
- @classmethod
27
- def from_pretrained(cls, pretrained_model_path, **kwargs):
28
- pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
29
- # Load block_embeddings tensor if it exists
30
- block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
31
- if os.path.exists(block_embeds_path):
32
- pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
33
- else:
34
- pipeline.block_embeddings = None
35
- return pipeline
36
-
37
-
38
-
39
- def give_sprite_scaling_factors(self, sprite_scaling_factors):
40
- """
41
- Set the sprite scaling factors for the pipeline.
42
- This is used to apply per-sprite temperature scaling during inference.
43
- """
44
- self.sprite_scaling_factors = sprite_scaling_factors
45
-
46
- def __call__(
47
- self,
48
- batch_size: int = 1,
49
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
50
- num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
51
- output_type: Optional[str] = "tensor",
52
- return_dict: bool = True,
53
- height: int = common_settings.MARIO_HEIGHT, width: int = common_settings.MARIO_WIDTH,
54
- latents: Optional[torch.FloatTensor] = None,
55
- show_progress_bar=True,
56
- ) -> Union[ImagePipelineOutput, Tuple]:
57
-
58
- self.unet.eval()
59
- with torch.no_grad():
60
-
61
- if latents is not None:
62
- image = latents.to(self.device)
63
- else:
64
- image_shape = (
65
- batch_size,
66
- self.unet.config.in_channels,
67
- height,
68
- width
69
- )
70
-
71
- image = torch.randn(image_shape, generator=generator, device=self.device)
72
-
73
- self.scheduler.set_timesteps(num_inference_steps)
74
-
75
- iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
76
- for t in iterator:
77
- #print(image.shape)
78
- model_output = self.unet(image, t).sample
79
- image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
80
-
81
- # Apply per-sprite temperature scaling if enabled
82
- if hasattr(self,"sprite_scaling_factors") and self.sprite_scaling_factors is not None:
83
- image = image / self.sprite_scaling_factors.view(1, -1, 1, 1)
84
-
85
-
86
- if self.block_embeddings is not None:
87
- image = get_scene_from_embeddings(image, self.block_embeddings)
88
- else:
89
- image = F.softmax(image, dim=1)
90
- image = image.detach().cpu()
91
-
92
- if not return_dict:
93
- return (image,)
94
-
95
- return ImagePipelineOutput(images=image)
96
-
97
- def print_unet_architecture(self):
98
- """Prints the architecture of the UNet model."""
99
- print(self.unet)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/pipeline_loader.py DELETED
@@ -1,41 +0,0 @@
1
- from models.text_diffusion_pipeline import TextConditionalDDPMPipeline
2
- from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline
3
- import os
4
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
5
-
6
-
7
- def get_pipeline(model_path):
8
- # If model_path is a local directory, use the original logic
9
- if os.path.isdir(model_path):
10
- #Diffusion models
11
- if os.path.exists(os.path.join(model_path, "unet")):
12
- if os.path.exists(os.path.join(model_path, "text_encoder")):
13
- #If it has a text encoder and a unet, it's text conditional diffusion
14
- pipe = TextConditionalDDPMPipeline.from_pretrained(model_path)
15
- else:
16
- #If it has no text encoder, use the unconditional diffusion model
17
- pipe = UnconditionalDDPMPipeline.from_pretrained(model_path)
18
- else:
19
- # Assume it's a Hugging Face Hub model ID
20
- # Try to load config to determine if it's text-conditional
21
- try:
22
- config, _ = DiffusionPipeline.load_config(model_path)
23
- components = config.get("components", {})
24
- except Exception:
25
- components = {}
26
- if "text_encoder" in components or "text_encoder" in str(components):
27
- # Use the local pipeline file for custom_pipeline
28
- pipe = DiffusionPipeline.from_pretrained(
29
- model_path,
30
- custom_pipeline="models.text_diffusion_pipeline.TextConditionalDDPMPipeline",
31
- trust_remote_code=True,
32
- )
33
- else:
34
- # Fallback: try unconditional
35
- pipe = DiffusionPipeline.from_pretrained(
36
- model_path,
37
- custom_pipeline="models.latent_diffusion_pipeline.UnconditionalDDPMPipeline",
38
- trust_remote_code=True,
39
- )
40
-
41
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/sentence_transformers_helper.py DELETED
@@ -1,114 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModel
2
- import torch
3
- import torch.nn.functional as F
4
-
5
- #Mean Pooling - Take average of all tokens
6
- def mean_pooling(model_output, attention_mask):
7
- token_embeddings = model_output.last_hidden_state
8
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
9
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
10
-
11
-
12
- #Encode text
13
- def encode(texts, tokenizer, model, device='cpu'):
14
- # Tokenize sentences
15
- encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
16
- encoded_input.to(device)
17
-
18
- # Compute token embeddings
19
- with torch.no_grad():
20
- model_output = model(**encoded_input, return_dict=True)
21
-
22
- # Perform pooling
23
- embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
24
-
25
- # Normalize embeddings
26
- embeddings = F.normalize(embeddings, p=2, dim=1)
27
-
28
- embeddings = embeddings.to(device)
29
-
30
- return embeddings
31
-
32
- # Get embeddings for a batch of captions and optional negative captions
33
- def get_embeddings(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu'):
34
- embeddings = encode([""]*batch_size, tokenizer, model, device)
35
-
36
- if captions is not None:
37
- caption_embeddings = encode(captions, tokenizer, model, device)
38
- embeddings = torch.cat((embeddings, caption_embeddings), dim=0)
39
-
40
- if neg_captions is not None:
41
- neg_embeddings = encode(neg_captions, tokenizer, model, device)
42
- embeddings = torch.cat((neg_embeddings, embeddings), dim=0)
43
-
44
-
45
- embeddings = embeddings.unsqueeze(1)
46
-
47
- return embeddings
48
-
49
-
50
-
51
-
52
- def get_embeddings_split(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu', max_length=20):
53
-
54
- padding_length = max(max([s.count(".") for s in captions]) if captions else 1,
55
- max([s.count(".") for s in neg_captions]) if neg_captions else 1)
56
- if (padding_length>max_length):
57
- raise ValueError(f"Token sequence length {padding_length} exceeds specified length {max_length}.")
58
-
59
-
60
- empty_split = split_sentences([""] * batch_size, padding_length)
61
- embeddings = get_embeddings_from_split(empty_split, tokenizer, model, device)
62
-
63
- if(captions is not None):
64
- captions_split = split_sentences(captions, padding_length)
65
- caption_embeddings = get_embeddings_from_split(captions_split, tokenizer, model, device)
66
- embeddings = torch.cat((embeddings, caption_embeddings), dim=0)
67
-
68
- if(neg_captions is not None):
69
- neg_split = split_sentences(neg_captions, padding_length)
70
- neg_embeddings = get_embeddings_from_split(neg_split, tokenizer, model, device)
71
- embeddings = torch.cat((neg_embeddings, embeddings), dim=0)
72
-
73
-
74
- #We don't need to unsqueeze this, we have an array of (batch_size, padding_length, encoding_size) already
75
-
76
- return embeddings.to(device)
77
-
78
-
79
- #This method takes a caption batch in list form, and outputs a 2d list where every caption has been split by period
80
- def split_sentences(caption_array, padding_length=20):
81
- split_caption_array = []
82
-
83
- #Padding happens here
84
- for caption in caption_array:
85
- split_caption = [s.strip() for s in caption.split(".") if s.strip()]
86
- #This is the token padding, we just use an empty string
87
- split_caption += [""] * (padding_length - len(split_caption))
88
- split_caption_array.append(split_caption)
89
-
90
- return split_caption_array
91
-
92
-
93
- #Expects all split vectors to be the same length
94
- def get_embeddings_from_split(caption_batch, tokenizer, model, device='cpu'):
95
- all_caption_encodings = []
96
- for caption_sequence in caption_batch:
97
- #Encode the sequence of split captions as if it was a batch, should now be a [maxlength, embeddingsize] tensor
98
- caption_sequence = encode(caption_sequence, tokenizer, model, device)
99
-
100
- #We don't reshape this to avoid having to unsqueeze it later
101
- all_caption_encodings.append(caption_sequence)
102
-
103
- all_caption_encodings = torch.stack(all_caption_encodings, dim=0)
104
- return all_caption_encodings
105
-
106
-
107
-
108
- if __name__ == "__main__":
109
- cap = split_sentences(["Hello. My name is George. How. Are you doing. Today?", "I am doing. Just fine. Thanks."])
110
- model_url = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
111
- device = 'cuda'
112
- tokenizer = AutoTokenizer.from_pretrained(model_url)
113
- model = AutoModel.from_pretrained(model_url, trust_remote_code=True).to(device)
114
- get_embeddings_from_split(cap, tokenizer, model, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/text_diffusion_pipeline.py DELETED
@@ -1,442 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from typing import NamedTuple, Optional
4
- import os
5
- from diffusers import DDPMPipeline, UNet2DConditionModel, DDPMScheduler
6
- import json
7
- # Running the main at the end of this requires messing with this import
8
- from models.text_model import TransformerModel
9
- import torch
10
- import torch.nn.functional as F
11
- from transformers import AutoTokenizer, AutoModel
12
- import util.common_settings as common_settings
13
- import models.sentence_transformers_helper as st_helper
14
- import models.text_model as text_model
15
- from models.general_training_helper import get_scene_from_embeddings
16
-
17
- class PipelineOutput(NamedTuple):
18
- images: torch.Tensor
19
-
20
-
21
-
22
- # Create a custom pipeline for text-conditional generation
23
- class TextConditionalDDPMPipeline(DDPMPipeline):
24
- def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
25
- super().__init__(unet=unet, scheduler=scheduler)
26
- self.text_encoder = text_encoder
27
- self.tokenizer = tokenizer
28
- self.supports_negative_prompt = hasattr(unet, 'negative_prompt_support') and unet.negative_prompt_support
29
- self.supports_pretrained_split = supports_pretrained_split
30
- self.block_embeddings = block_embeddings
31
-
32
- if self.tokenizer is None and self.text_encoder is not None:
33
- # Use the tokenizer from the text encoder if not provided
34
- self.tokenizer = self.text_encoder.tokenizer
35
-
36
- # Register the text_encoder so that .to(), .cpu(), .cuda(), etc. work correctly
37
- self.register_modules(
38
- unet=unet,
39
- scheduler=scheduler,
40
- text_encoder=self.text_encoder,
41
- tokenizer=self.tokenizer,
42
- )
43
-
44
- # Override the to() method to ensure text_encoder is moved to the correct device
45
- def to(self, device=None, dtype=None):
46
- # Call the parent's to() method first
47
- pipeline = super().to(device, dtype)
48
-
49
- # Additionally move the text_encoder to the device
50
- if self.text_encoder is not None:
51
- self.text_encoder.to(device)
52
-
53
- return pipeline
54
-
55
- def save_pretrained(self, save_directory):
56
- os.makedirs(save_directory, exist_ok=True)
57
- super().save_pretrained(save_directory) # saves UNet and scheduler
58
-
59
- # Save block_embeddings tensor if it exists
60
- if self.block_embeddings is not None:
61
- torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))
62
-
63
- # Save supports_negative_prompt and supports_pretrained_split flags
64
- with open(os.path.join(save_directory, "pipeline_config.json"), "w") as f:
65
- json.dump({
66
- "supports_negative_prompt": self.supports_negative_prompt,
67
- "supports_pretrained_split": self.supports_pretrained_split,
68
- "text_encoder_type": type(self.text_encoder).__name__
69
- }, f)
70
-
71
-
72
- #Text encoder/tokenizer saving is different depending on if we're using a larger pretrained model
73
- if isinstance(self.text_encoder, TransformerModel):
74
- # Save custom text encoder
75
- if self.text_encoder is not None:
76
- self.text_encoder.save_pretrained(os.path.join(save_directory, "text_encoder"))
77
- else:
78
- #Save pretrained tokenizer by name, so we can load from huggingface instead of saving a giant local model
79
- text_encoder_info = {
80
- "text_encoder_name": self.text_encoder.config.name_or_path,
81
- "tokenizer_name": self.tokenizer.name_or_path,
82
- }
83
-
84
- text_encoder_directory = os.path.join(save_directory, "text_encoder")
85
- os.makedirs(text_encoder_directory, exist_ok=True)
86
-
87
- with open(os.path.join(text_encoder_directory, "loading_info.json"), "w") as f:
88
- json.dump(text_encoder_info, f)
89
-
90
-
91
-
92
- @classmethod
93
- def from_pretrained(cls, pretrained_model_path, **kwargs):
94
- #from diffusers.utils import load_config, load_state_dict
95
- # Load model_index.json
96
- #model_index = load_config(pretrained_model_path)
97
-
98
- # Load components manually
99
- unet_path = os.path.join(pretrained_model_path, "unet")
100
- unet = UNet2DConditionModel.from_pretrained(unet_path)
101
-
102
- scheduler_path = os.path.join(pretrained_model_path, "scheduler")
103
- # Have heard that DDIMScheduler might be faster for inference, though not necessarily better
104
- scheduler = DDPMScheduler.from_pretrained(scheduler_path)
105
-
106
- tokenizer = None
107
- text_encoder_path = os.path.join(pretrained_model_path, "text_encoder")
108
-
109
- if os.path.exists(text_encoder_path):
110
- #Test for the new saving system, where we save a simple config file
111
- if os.path.exists(os.path.join(text_encoder_path, "loading_info.json")):
112
- with open(os.path.join(text_encoder_path, "loading_info.json"), "r") as f:
113
- encoder_config = json.load(f)
114
-
115
- text_encoder = AutoModel.from_pretrained(encoder_config['text_encoder_name'], trust_remote_code=True)
116
- tokenizer = AutoTokenizer.from_pretrained(encoder_config['tokenizer_name'])
117
-
118
- #Legacy loading system, loads models directly if the whole thing is saved in the directory
119
- else:
120
- try:
121
- text_encoder = AutoModel.from_pretrained(text_encoder_path, local_files_only=True, trust_remote_code=True)
122
- tokenizer = AutoTokenizer.from_pretrained(text_encoder_path, local_files_only=True)
123
- except (ValueError, KeyError):
124
- text_encoder = TransformerModel.from_pretrained(text_encoder_path)
125
- tokenizer = text_encoder.tokenizer
126
- else:
127
- text_encoder = None
128
-
129
- # Instantiate your pipeline
130
- pipeline = cls(
131
- unet=unet,
132
- scheduler=scheduler,
133
- text_encoder=text_encoder,
134
- tokenizer=tokenizer,
135
- **kwargs,
136
- )
137
-
138
- #Loads block embeddings if present
139
- block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
140
- if os.path.exists(block_embeds_path):
141
- pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
142
- else:
143
- pipeline.block_embeddings = None
144
-
145
-
146
- # Load supports_negative_prompt flag if present
147
- config_path = os.path.join(pretrained_model_path, "pipeline_config.json")
148
- if os.path.exists(config_path):
149
- with open(config_path, "r") as f:
150
- config = json.load(f)
151
- pipeline.supports_negative_prompt = config.get("supports_negative_prompt", False)
152
- pipeline.supports_pretrained_split = config.get("supports_pretrained_split", False)
153
- return pipeline
154
-
155
- # --- Handle batching for captions ---
156
- def _prepare_text_batch(self, text: Optional[str | list[str]], batch_size: int, name: str) -> Optional[list[str]]:
157
- if text is None:
158
- return None
159
- if isinstance(text, str):
160
- return [text] * batch_size
161
- if isinstance(text, list):
162
- if len(text) == 1:
163
- return text * batch_size
164
- if len(text) != batch_size:
165
- raise ValueError(f"{name} list length {len(text)} does not match batch_size {batch_size}")
166
- return text
167
- raise ValueError(f"{name} must be a string or list of strings")
168
-
169
- def _prepare_initial_sample(self,
170
- raw_latent_sample: Optional[torch.Tensor],
171
- input_scene: Optional[torch.Tensor],
172
- batch_size: int, height: int, width: int,
173
- generator: Optional[torch.Generator]) -> torch.Tensor:
174
- """Prepare the initial sample for diffusion."""
175
-
176
- sample_shape = (batch_size, self.unet.config.in_channels, height, width)
177
-
178
- if raw_latent_sample is not None:
179
- if input_scene is not None:
180
- raise ValueError("Cannot provide both raw_latent_sample and input_scene")
181
- sample = raw_latent_sample.to(self.device)
182
- if sample.shape[1] != sample_shape[1]:
183
- raise ValueError(f"Wrong number of channels in raw_latent_sample: Expected {self.unet.config.in_channels} but got {sample.shape[1]}")
184
- if sample.shape[0] == 1 and batch_size > 1:
185
- sample = sample.repeat(batch_size, 1, 1, 1)
186
- elif sample.shape[0] != batch_size:
187
- raise ValueError(f"raw_latent_sample batch size {sample.shape[0]} does not match batch_size {batch_size}")
188
- elif input_scene is not None:
189
- # input_scene can be (H, W) or (batch_size, H, W)
190
- scene_tensor = torch.tensor(input_scene, dtype=torch.long, device=self.device)
191
- if scene_tensor.dim() == 2:
192
- # (H, W) -> repeat for batch
193
- scene_tensor = scene_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
194
- elif scene_tensor.shape[0] == 1 and batch_size > 1:
195
- scene_tensor = scene_tensor.repeat(batch_size, 1, 1)
196
- elif scene_tensor.shape[0] != batch_size:
197
- raise ValueError(f"input_scene batch size {scene_tensor.shape[0]} does not match batch_size {batch_size}")
198
- # One-hot encode: (batch, H, W, C)
199
- one_hot = F.one_hot(scene_tensor, num_classes=self.unet.config.in_channels).float()
200
- # (batch, H, W, C) -> (batch, C, H, W)
201
- sample = one_hot.permute(0, 3, 1, 2)
202
- else:
203
- # Start from random noise
204
- sample = torch.randn(sample_shape, generator=generator, device=self.device)
205
-
206
- return sample
207
-
208
- def __call__(
209
- self,
210
- caption: Optional[str | list[str]] = None,
211
- negative_prompt: Optional[str | list[str]] = None,
212
- generator: Optional[torch.Generator] = None,
213
- num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
214
- guidance_scale: float = common_settings.GUIDANCE_SCALE,
215
- height: int = common_settings.MARIO_HEIGHT,
216
- width: int = common_settings.MARIO_WIDTH,
217
- raw_latent_sample: Optional[torch.FloatTensor] = None,
218
- input_scene: Optional[torch.Tensor] = None,
219
- output_type: str = "tensor",
220
- batch_size: int = 1,
221
- show_progress_bar: bool = True,
222
- ) -> PipelineOutput:
223
- """Generate a batch of images based on text input using the diffusion model.
224
-
225
- Args:
226
- caption: Text description(s) of the desired output. Can be a string or list of strings.
227
- negative_prompt: Text description(s) of what should not appear in the output. String or list.
228
- generator: Random number generator for reproducibility.
229
- num_inference_steps: Number of denoising steps (more = higher quality, slower).
230
- guidance_scale: How strongly the generation follows the text prompt (higher = stronger).
231
- height: Height of generated image in tiles.
232
- width: Width of generated image in tiles.
233
- raw_latent_sample: Optional starting point for diffusion instead of random noise.
234
- Must have correct number of channels matching the UNet.
235
- input_scene: Optional 2D or 3D int tensor where each value corresponds to a tile type.
236
- Will be converted to one-hot encoding as starting point.
237
- output_type: Currently only "tensor" is supported.
238
- batch_size: Number of samples to generate in parallel.
239
-
240
- Returns:
241
- PipelineOutput containing the generated image tensor (batch_size, ...).
242
- """
243
-
244
- # I would like to simplify the code to this, but the AI suggestion didn't work, and
245
- # I did not feel good just pasting it all in. Will need to tackle it bit by bit.
246
-
247
- # if caption is not None and self.text_encoder is None:
248
- # raise ValueError("Text encoder required for conditional generation")
249
-
250
- # self.unet.eval()
251
- # if self.text_encoder is not None:
252
- # self.text_encoder.to(self.device)
253
- # self.text_encoder.eval()
254
- #
255
- # with torch.no_grad():
256
- # # Process text inputs
257
- # captions = self.prepare_text_batch(caption, batch_size, "caption")
258
- # negatives = self.prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
259
-
260
- # # Get embeddings
261
- # text_embeddings = self.prepare_embeddings(captions, negatives, batch_size)
262
- #
263
- # # Set up initial latent state
264
- # sample = self.prepare_initial_sample(raw_latent_sample, input_scene,
265
- # batch_size, height, width, generator)
266
-
267
- # # Run diffusion process
268
- # sample = self.run_diffusion(sample, text_embeddings, num_inference_steps,
269
- # guidance_scale, generator, show_progress_bar,
270
- # has_caption=caption is not None,
271
- # has_negative=negative_prompt is not None)
272
-
273
- # # Format output
274
- # if output_type == "tensor":
275
- # sample = F.softmax(sample, dim=1)
276
- # else:
277
- # raise ValueError(f"Unsupported output type: {output_type}")
278
-
279
- # return PipelineOutput(images=sample)
280
-
281
- # Validate text encoder if we need it
282
- if caption is not None and self.text_encoder is None:
283
- raise ValueError("Text encoder is required for conditional generation")
284
-
285
- self.unet.eval()
286
- if self.text_encoder is not None:
287
- self.text_encoder.to(self.device)
288
- self.text_encoder.eval()
289
-
290
- with torch.no_grad():
291
- captions = self._prepare_text_batch(caption, batch_size, "caption")
292
- negatives = self._prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
293
-
294
- # --- Prepare text embeddings ---
295
- if(isinstance(self.text_encoder, TransformerModel)):
296
- text_embeddings = text_model.get_embeddings(batch_size=batch_size,
297
- tokenizer=self.text_encoder.tokenizer,
298
- text_encoder=self.text_encoder,
299
- captions=captions,
300
- neg_captions=negatives,
301
- device=self.device)
302
- else: #Case for the pre-trained text encoder
303
- if(self.supports_pretrained_split): #If we have a split flag incorporated
304
- text_embeddings = st_helper.get_embeddings_split(batch_size = batch_size,
305
- tokenizer=self.tokenizer,
306
- model=self.text_encoder,
307
- captions=captions,
308
- neg_captions=negatives,
309
- device=self.device)
310
- else:
311
- text_embeddings = st_helper.get_embeddings(batch_size = batch_size,
312
- tokenizer=self.tokenizer,
313
- model=self.text_encoder,
314
- captions=captions,
315
- neg_captions=negatives,
316
- device=self.device)
317
-
318
-
319
- # --- Set up initial latent state ---
320
- sample = self._prepare_initial_sample(raw_latent_sample, input_scene,
321
- batch_size, height, width, generator)
322
-
323
- # --- Set up diffusion process ---
324
- self.scheduler.set_timesteps(num_inference_steps)
325
-
326
- # Denoising loop
327
- iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
328
- for t in iterator:
329
- # Handle conditional generation
330
- if captions is not None:
331
- if negatives is not None:
332
- # Three copies for negative prompt guidance
333
- model_input = torch.cat([sample, sample, sample], dim=0)
334
- else:
335
- # Two copies for standard classifier-free guidance
336
- model_input = torch.cat([sample, sample], dim=0)
337
- else:
338
- model_input = sample
339
-
340
- # Predict noise residual
341
- model_kwargs = {"encoder_hidden_states": text_embeddings}
342
- noise_pred = self.unet(model_input, t, **model_kwargs).sample
343
-
344
- # Apply guidance
345
- if captions is not None:
346
- if negatives is not None:
347
- # Split predictions for negative, unconditional, and text-conditional
348
- noise_pred_neg, noise_pred_uncond, noise_pred_text = noise_pred.chunk(3)
349
- noise_pred_guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
350
- noise_pred = noise_pred_guided - guidance_scale * (noise_pred_neg - noise_pred_uncond)
351
- else:
352
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
353
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
354
-
355
- # Compute previous sample: x_{t-1} = scheduler(x_t, noise_pred)
356
- sample = self.scheduler.step(noise_pred, t, sample, generator=generator).prev_sample
357
-
358
- # Convert to output format
359
- if output_type == "tensor":
360
- if self.block_embeddings is not None:
361
- sample = get_scene_from_embeddings(sample, self.block_embeddings)
362
- else:
363
- # Apply softmax to get probabilities for each tile type
364
- sample = F.softmax(sample, dim=1)
365
- sample = sample.detach().cpu()
366
- else:
367
- raise ValueError(f"Unsupported output type: {output_type}")
368
-
369
- return PipelineOutput(images=sample)
370
-
371
- def print_unet_architecture(self):
372
- """Prints the architecture of the UNet model."""
373
- print(self.unet)
374
-
375
- def print_text_encoder_architecture(self):
376
- """Prints the architecture of the text encoder model, if it exists."""
377
- if self.text_encoder is not None:
378
- print(self.text_encoder)
379
- else:
380
- print("No text encoder is set.")
381
-
382
- def save_unet_architecture_pdf(self, height, width, filename="unet_architecture", batch_size=1, device=None):
383
- """
384
- Have to separately install torchview for this to work
385
-
386
- Saves a visualization of the UNet architecture as a PDF using torchview.
387
- Args:
388
- height: Height of the dummy input.
389
- width: Width of the dummy input.
390
- filename: Output PDF filename.
391
- batch_size: Batch size for dummy input.
392
- device: Device to run the dummy input on (defaults to pipeline device).
393
- """
394
- from torchview import draw_graph
395
- import graphviz
396
-
397
- if device is None:
398
- device = self.device if hasattr(self, 'device') else 'cpu'
399
- in_channels = self.unet.config.in_channels if hasattr(self.unet, 'config') else 1
400
- sample_shape = tuple([batch_size, in_channels, height, width])
401
-
402
- dummy_x = torch.randn(size=sample_shape, device=device)
403
- dummy_t = torch.tensor([0] * batch_size, dtype=torch.long, device=device)
404
-
405
- # Prepare dummy text embedding (match what your UNet expects)
406
- if hasattr(self.unet, 'config') and hasattr(self.unet.config, 'cross_attention_dim'):
407
- cross_attention_dim = self.unet.config.cross_attention_dim
408
- else:
409
- cross_attention_dim = 128 # fallback
410
- encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim, device=device)
411
-
412
- self.unet.eval()
413
- inputs = (dummy_x, dummy_t, encoder_hidden_states)
414
- #self.unet.down_blocks = self.unet.down_blocks[:2]
415
-
416
- graph = draw_graph(
417
- model=self.unet,
418
- input_data=inputs,
419
- expand_nested=False,
420
- #enable_output_shape=True,
421
- #roll_out="nested",
422
- depth=1
423
- )
424
- #graph.visual_graph.engine = "neato"
425
- graph.visual_graph.attr(#rankdir="LR",
426
- nodesep="0.1", # decrease space between nodes in the same rank (default ~0.25)
427
- ranksep="0.2", # decrease space between ranks (default ~0.5)
428
- concentrate="true" # merge edges between nodes in the same rank
429
- )
430
- graph.visual_graph.node_attr.update(
431
- shape="rectangle",
432
- width="1.5", # narrow width
433
- height="0.5" # taller height to make vertical rectangles
434
- #fixedsize="true"
435
- )
436
-
437
- graph.visual_graph.render(filename, format='pdf', cleanup=False) # Cleanup removes intermediate files
438
- graph.visual_graph.save('unet_architecture.dot')
439
-
440
- # Save the graph to a PDF file
441
- print(f"UNet architecture saved to {filename}")
442
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/text_model.py DELETED
@@ -1,206 +0,0 @@
1
- import argparse
2
- from xml.parsers.expat import model
3
- import torch
4
- import torch.nn as nn
5
- import math
6
- import os
7
- import json
8
- from safetensors.torch import save_file, load_file
9
- from tokenizer import Tokenizer
10
-
11
- def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
12
- max_length = text_encoder.max_seq_length
13
- empty_ids = encode_token_captions([""] * batch_size, tokenizer, max_length, device=device)
14
- embeddings = text_encoder.get_embeddings(empty_ids)
15
-
16
- if(captions is not None):
17
- caption_ids = encode_token_captions(captions, tokenizer, max_length, device=device)
18
- caption_embeddings = text_encoder.get_embeddings(caption_ids)
19
- embeddings = torch.cat((embeddings, caption_embeddings), dim=0)
20
-
21
- if(neg_captions is not None):
22
- neg_ids = encode_token_captions(neg_captions, tokenizer, max_length, device=device)
23
- neg_embeddings = text_encoder.get_embeddings(neg_ids)
24
- embeddings = torch.cat((neg_embeddings, embeddings), dim=0)
25
-
26
- return embeddings.to(device)
27
-
28
- def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
29
- caption_ids = []
30
- for caption in captions:
31
- tokens = tokenizer.encode(caption)
32
- caption_tokens = tokenizer.pad_sequence(tokens, max_length)
33
- caption_ids.append(torch.tensor(caption_tokens, dtype=torch.long).unsqueeze(0))
34
- return torch.cat(caption_ids, dim=0).to(device)
35
-
36
-
37
-
38
-
39
-
40
-
41
-
42
-
43
-
44
- # Transformer model for MLM training
45
-
46
- class TransformerModel(nn.Module):
47
- def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
48
- super().__init__()
49
- self.embedding_dim = embedding_dim
50
- self.vocab_size = vocab_size
51
- self.hidden_dim = hidden_dim
52
- self.num_heads = num_heads
53
- self.num_layers = num_layers
54
- self.max_seq_length = max_seq_length
55
-
56
- self.embedding = nn.Embedding(vocab_size, embedding_dim)
57
- self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)
58
-
59
- encoder_layers = nn.TransformerEncoderLayer(
60
- d_model=embedding_dim,
61
- nhead=num_heads,
62
- dim_feedforward=hidden_dim,
63
- batch_first=True
64
- )
65
- self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
66
- self.fc = nn.Linear(embedding_dim, vocab_size)
67
-
68
- self.tokenizer = tokenizer
69
-
70
- def create_positional_encoding(self, max_seq_length, embedding_dim):
71
- # The implementation uses a sinusoidal positional encoding, which creates a unique pattern for each position in the sequence.
72
- # The frequencies create unique values, the sin/cos bounds values
73
- position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
74
- # Creates a set of divisors that create different frequencies
75
- div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
76
- pe = torch.zeros(max_seq_length, embedding_dim)
77
- # Even dimensions use sin, odd dimensions use cos
78
- pe[:, 0::2] = torch.sin(position * div_term)
79
- pe[:, 1::2] = torch.cos(position * div_term)
80
- return pe.unsqueeze(0)
81
-
82
- def get_embeddings(self, x):
83
- """ This gets the actual latent embedding vectors """
84
- # Ensure positional encoding is on the same device as input
85
- pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
86
- # Embed input and add positional encoding
87
- embedded = self.embedding(x) + pe
88
- return self.transformer(embedded)
89
-
90
- def forward(self, x):
91
- """ This gets the token within the vocabulary """
92
- transformer_out = self.get_embeddings(x)
93
- # Project to vocabulary size
94
- return self.fc(transformer_out)
95
-
96
- def save_pretrained(self, save_directory):
97
- os.makedirs(save_directory, exist_ok=True)
98
-
99
- config = {
100
- "vocab_size": self.vocab_size,
101
- "embedding_dim": self.embedding_dim,
102
- "hidden_dim": self.hidden_dim,
103
- "num_heads": self.num_heads,
104
- "num_layers": self.num_layers,
105
- "max_seq_length": self.max_seq_length,
106
- }
107
- with open(os.path.join(save_directory, "config.json"), "w") as f:
108
- json.dump(config, f)
109
-
110
- # Save model weights
111
- save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
112
-
113
- # Save tokenizer if present
114
- if self.tokenizer is not None:
115
- self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))
116
-
117
- @classmethod
118
- def from_pretrained(cls, load_directory):
119
- with open(os.path.join(load_directory, "config.json")) as f:
120
- config = json.load(f)
121
-
122
- model = cls(**config)
123
-
124
- # Load weights
125
- state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
126
- model.load_state_dict(state_dict)
127
-
128
- # Load tokenizer if available
129
- tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
130
- if os.path.exists(tokenizer_path):
131
- tokenizer = Tokenizer()
132
- tokenizer.load(tokenizer_path)
133
- model.tokenizer = tokenizer
134
-
135
- return model
136
-
137
- def print_architecture(self, inputs=None):
138
- parser = argparse.ArgumentParser()
139
- parser.add_argument("--model_path", type=str, required=True, help="Path to trained transformer model")
140
- parser.add_argument("--json", type=str, default="SMB1_LevelsAndCaptions-regular-test.json", help="Path to dataset json file")
141
- parser.add_argument("--num_samples", type=int, default=10, help="Number of captions to evaluate")
142
- parser.add_argument("--mask_prob", type=float, default=0.15, help="Probability of masking each token")
143
-
144
- parser.add_argument("--compare_checkpoints", action="store_true", default=False, help="Run comparison across all model checkpoints")
145
- args = parser.parse_args()
146
-
147
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
148
- model = TransformerModel.from_pretrained(args.model_path).to(device)
149
- print(f"Loaded model from {args.model_path}")
150
-
151
- import os
152
- import re
153
- import json
154
- import matplotlib.pyplot as plt
155
- from torchview import draw_graph
156
- import graphviz
157
-
158
- graph = draw_graph(
159
- model=model,
160
- input_data=inputs,
161
- expand_nested=False,
162
- #enable_output_shape=True,
163
- #roll_out="nested",
164
- depth=1
165
- )
166
-
167
- # Save plot
168
- filename = 'mlm_architecture'
169
- graph.visual_graph.render(filename, format='pdf', cleanup=False) # Cleanup removes intermediate files
170
- #graph.visual_graph.save('unet_architecture.dot')
171
-
172
- def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
173
- """Save a visualization of the model architecture as a PDF using torchview."""
174
- try:
175
- from torchview import draw_graph
176
- except ImportError:
177
- raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")
178
- import torch
179
- import os
180
- # Create a dummy input of the correct type for the model
181
- captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
182
- tensor = encode_token_captions(captions, self.tokenizer, self.max_seq_length, device=next(self.parameters()).device)
183
- input_length = tensor.size(1) if tensor.dim() > 1 else self.max_seq_length
184
-
185
- num_tokens_list = [len(self.tokenizer.encode(c)) for c in captions]
186
- input_length = max(num_tokens_list) if num_tokens_list else input_length
187
- dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=next(self.parameters()).device)
188
-
189
- # Draw the graph and save as PNG
190
- graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf',''), directory=".", depth=2)
191
- png_file = filename.replace('.pdf', '.png')
192
- # Convert PNG to PDF
193
- if os.path.exists(png_file):
194
- try:
195
- from PIL import Image
196
- im = Image.open(png_file)
197
- im.save(filename, "PDF", resolution=100.0)
198
- print(f"Saved architecture PDF to {filename}")
199
- # Optionally, remove the PNG file
200
- os.remove(png_file)
201
- except ImportError:
202
- print(f"PIL not installed. Architecture saved as PNG: {png_file}")
203
- except Exception as e:
204
- print(f"Could not convert PNG to PDF: {e}")
205
- else:
206
- print(f"Could not find PNG file to convert: {png_file}")