schrum2 committed on
Commit
a09cfc1
·
verified ·
1 Parent(s): ef6bba8

Loading into root will supposedly make them easier to find

Browse files
caption_match.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from create_ascii_captions import assign_caption
2
+
3
# Quantity words in the order used for scoring partial matches: the score is
# derived from the distance between two words' positions in this list.
QUANTITY_TERMS = ["one", "two", "a few", "several", "many"]

# Topics to compare. Each entry is matched as a substring of a caption phrase,
# and entries are tried in this order.
TOPIC_KEYWORDS = [
    #"giant gap", # I think all gaps are subsumed by the floor topic
    "floor",
    "ceiling",
    "broken pipe",
    "upside down pipe",
    "pipe",
    "coin line",
    "coin",
    "platform",
    "tower",  #"wall",
    "broken cannon",
    "cannon",
    "ascending staircase",
    "descending staircase",
    "rectangular",
    "irregular",
    "question block",
    "loose block",
    "enem",  # catch "enemy"/"enemies"
]

# Need list because the order matters
KEYWORD_TO_NEGATED_PLURAL = [
    (" broken pipe.", ""),  # If not the first phrase
    ("broken pipe. ", ""),  # If the first phrase (after removing all others)
    (" broken cannon.", ""),  # If not the first phrase
    ("broken cannon. ", ""),  # If the first phrase (after removing all others)
    ("pipe", "pipes"),
    ("cannon", "cannons"),
    ("platform", "platforms"),
    ("tower", "towers"),
    ("staircase", "staircases"),
    ("enem", "enemies"),
    ("rectangular", "rectangular block clusters"),
    ("irregular", "irregular block clusters"),
    ("coin line", "coin lines"),
    ("coin.", "coins."),  # Need period to avoid matching "coin line"
    ("question block", "question blocks"),
    ("loose block", "loose blocks"),
]

# Number of topics that are considered "broken" (e.g., "broken pipe", "broken cannon")
BROKEN_TOPICS = 2

# Irregular plurals mapped to their singular form; applied before the generic
# trailing-"s" stripping in normalize_plural.
PLURAL_EXCEPTIONS = {
    "enemies": "enemy",
}
47
+
48
def normalize_plural(phrase):
    """Return ``phrase`` with plural words reduced to a singular-looking form.

    Irregular plurals listed in PLURAL_EXCEPTIONS are substituted first (as
    plain substring replacements). After that, every whitespace-separated word
    ending in a single "s" loses that "s"; words ending in "ss" are left alone.
    Words are re-joined with single spaces.
    """
    for irregular, base in PLURAL_EXCEPTIONS.items():
        phrase = phrase.replace(irregular, base)

    def _strip_plural_s(token):
        # Keep words such as "class"/"boss" intact; only a lone trailing "s" goes.
        drop_s = token.endswith('s') and not token.endswith('ss')
        return token[:-1] if drop_s else token

    return ' '.join(_strip_plural_s(token) for token in phrase.split())
63
+
64
def extract_phrases(caption, debug=False):
    """Map every topic in TOPIC_KEYWORDS to the caption phrase that mentions it.

    The caption is split on "." into stripped, non-empty phrases. Topics are
    visited in TOPIC_KEYWORDS order; each topic takes the first phrase that
    contains it and has not already been claimed by an earlier topic.

    Args:
        caption (str): Caption made of "."-separated phrases.
        debug (bool): If True, print what was found for each topic.

    Returns:
        dict: topic -> phrase, or None when no phrase matches or the matching
        phrase starts with "no " (treated as absence).
    """
    sentences = [piece.strip() for piece in caption.split('.')]
    sentences = [sentence for sentence in sentences if sentence]

    claimed = set()  # phrases already assigned to an earlier topic
    topic_to_phrase = {}

    for topic in TOPIC_KEYWORDS:
        hit = next(
            (s for s in sentences if topic in s and s not in claimed),
            None,
        )

        if hit is None:
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': no phrase found")
            continue

        if hit.lower().startswith("no "):
            # "no ..." phrases are equivalent to absence; the phrase is not
            # marked as claimed, so later topics can still see it.
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
            continue

        topic_to_phrase[topic] = hit
        claimed.add(hit)
        if debug:
            print(f"[Extract] Topic '{topic}': found phrase '{hit}'")

    return topic_to_phrase
95
+
96
def quantity_score(phrase1, phrase2, debug=False):
    """Score how close the quantity words of two phrases are.

    The first QUANTITY_TERMS entry found as a substring of each phrase is used.
    When both phrases contain one, the score is 1.0 minus the normalised
    distance between their positions in QUANTITY_TERMS (1.0 for identical
    terms, 0.0 for the two extremes). Otherwise the score is 0.1.

    Args:
        phrase1 (str): First phrase.
        phrase2 (str): Second phrase.
        debug (bool): If True, print the intermediate values.

    Returns:
        float: Similarity score as described above.
    """
    def find_quantity(phrase):
        return next((term for term in QUANTITY_TERMS if term in phrase), None)

    qty1, qty2 = find_quantity(phrase1), find_quantity(phrase2)

    if debug:
        print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")

    if not (qty1 and qty2):
        if debug:
            print("[Quantity] At least one quantity missing, assigning partial score 0.1")
        return 0.1

    idx1 = QUANTITY_TERMS.index(qty1)
    idx2 = QUANTITY_TERMS.index(qty2)
    diff = abs(idx1 - idx2)
    max_diff = len(QUANTITY_TERMS) - 1
    score = 1.0 - (diff / max_diff)
    if debug:
        print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
    return score
121
+
122
def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
    """Score a generated caption against a reference caption, topic by topic.

    Per topic in TOPIC_KEYWORDS:
      * absent from both captions            -> +1.0
      * present in exactly one caption       -> -1.0
      * present in both, equal after plural normalisation -> +1.0
      * present in both, both mention a quantity word     -> quantity_score(...)
      * present in both otherwise            -> +0.1
    The sum is divided by the number of topics.

    Args:
        correct_caption (str): Reference caption.
        generated_caption (str): Caption being evaluated.
        debug (bool): If True, print a trace of every decision.
        return_matches (bool): If True, also return the matched phrase lists.

    Returns:
        float, or (float, list, list, list) when ``return_matches`` is True:
        the final score, then exact matches, partial matches and excess
        phrases (generated phrases whose topic is absent from the reference).
    """
    reference_by_topic = extract_phrases(correct_caption, debug=debug)
    candidate_by_topic = extract_phrases(generated_caption, debug=debug)

    exact_matches, partial_matches, excess_phrases = [], [], []
    total_score = 0.0
    num_topics = len(TOPIC_KEYWORDS)

    def _mentions_quantity(text):
        return any(term in text for term in QUANTITY_TERMS)

    if debug:
        print("\n--- Starting Topic Comparison ---\n")

    for topic in TOPIC_KEYWORDS:
        correct = reference_by_topic[topic]
        generated = candidate_by_topic[topic]

        if debug:
            print(f"[Topic: {topic}] Correct: {correct} | Generated: {generated}")

        missing_count = (correct is None) + (generated is None)

        if missing_count == 2:
            total_score += 1.0
            if debug:
                print(f"[Topic: {topic}] Both None — full score: 1.0\n")
        elif missing_count == 1:
            total_score -= 1.0
            if generated is not None:  # Considered an excess phrase
                excess_phrases.append(generated)
            if debug:
                print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
        else:
            # Normalize pluralization before comparison
            norm_correct = normalize_plural(correct)
            norm_generated = normalize_plural(generated)

            if debug:
                print(f"[Topic: {topic}] Normalized: Correct: '{norm_correct}' | Generated: '{norm_generated}'")

            if norm_correct == norm_generated:
                total_score += 1.0
                exact_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Exact match — score: 1.0\n")
            elif _mentions_quantity(norm_correct) and _mentions_quantity(norm_generated):
                qty_score = quantity_score(norm_correct, norm_generated, debug=debug)
                total_score += qty_score
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Quantity-based partial score: {qty_score:.2f}\n")
            else:
                total_score += 0.1
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")

        if debug:
            print(f"[Topic: {topic}] Current total score: {total_score:.4f}\n")

    if debug:
        print("total_score before normalization:", total_score)
        print(f"Number of topics: {num_topics}")

    final_score = total_score / num_topics
    if debug:
        print(f"--- Final score: {final_score:.4f} ---\n")

    if not return_matches:
        return final_score
    return final_score, exact_matches, partial_matches, excess_phrases
193
+
194
def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
    """
    Process a scene by partitioning it into segments, assigning captions, and computing comparison scores.

    An empty scene (no rows) produces no segments and an average score of 0
    instead of raising.

    Args:
        scene (list): The scene to process, represented as a 2D list.
        segment_width (int): The width of each segment.
        prompt (str): The prompt to compare captions against.
        id_to_char (dict): Mapping from tile IDs to characters.
        char_to_id (dict): Mapping from characters to tile IDs.
        tile_descriptors (dict): Descriptions of individual tile types.
        describe_locations (bool): Whether to include location descriptions in captions.
        describe_absence (bool): Whether to indicate absence of items in captions.
        verbose (bool): If True, print captions and scores for each segment.

    Returns:
        tuple: A tuple containing the average comparison score, captions for each segment, and scores for each segment.
    """
    # The scene width is read from the first row; a scene without rows has
    # width 0. len() is used rather than truthiness so array-like scenes work.
    scene_width = len(scene[0]) if len(scene) > 0 else 0

    # Partition the scene into vertical slices of the specified width
    segments = [
        [row[start:start + segment_width] for row in scene]
        for start in range(0, scene_width, segment_width)
    ]

    # Assign captions and compute scores for each segment
    segment_scores = []
    segment_captions = []
    for idx, segment in enumerate(segments):
        segment_caption = assign_caption(segment, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence)
        segment_score = compare_captions(prompt, segment_caption)
        segment_scores.append(segment_score)
        segment_captions.append(segment_caption)

        if verbose:
            print(f"Segment {idx + 1} caption: {segment_caption}")
            print(f"Segment {idx + 1} comparison score: {segment_score}")

    # Compute the average comparison score
    average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0

    if verbose:
        print(f"Average comparison score across all segments: {average_score}")

    return average_score, segment_captions, segment_scores
238
+
239
if __name__ == '__main__':
    # Small demo: score a generated caption against a reference with tracing on.
    reference_caption = "floor with one gap. two enemies. one platform. one tower."
    generated_caption = "giant gap with one chunk of floor. two enemies. one platform. one tower."

    demo_score = compare_captions(reference_caption, generated_caption, debug=True)

    print(f"Should be: {reference_caption}")
    print(f" but was: {generated_caption}")
    print(f"Score: {demo_score}")
common_settings.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Default sampling settings for the diffusion pipelines.
NUM_INFERENCE_STEPS = 30
GUIDANCE_SCALE = 7.5

# Mario scene dimensions (presumably measured in tiles -- confirm against callers).
MARIO_HEIGHT = 16
MARIO_WIDTH = 16

# Mario tile rendering size in pixels and the number of distinct tile types.
MARIO_TILE_PIXEL_DIM = 16
MARIO_TILE_COUNT = 13

# "LR" scene dimensions (NOTE(review): LR is not expanded anywhere in this file).
LR_HEIGHT = 32
LR_WIDTH = 32

# "LR" tile rendering size in pixels and the number of distinct tile types.
LR_TILE_PIXEL_DIM = 8
LR_TILE_COUNT = 8

# Mega Man scene dimensions.
MEGAMAN_HEIGHT = 14
MEGAMAN_WIDTH = 16
general_training_helper.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from level_dataset import LevelDataset
3
+ import random
4
+ from plotter import Plotter
5
+ from datetime import datetime
6
+ import os
7
+ import threading
8
+ import json
9
+ import torch.nn.functional as F
10
+ import torch
11
+
12
+
13
+
14
+
15
def create_dataloaders(json_path, val_json, tokenizer, data_mode, augment, num_tiles,
                       negative_prompt_training, block_embeddings, batch_size):
    """Build the training DataLoader and, optionally, a validation DataLoader.

    The training dataset is shuffled and uses ``augment`` as given; the
    validation dataset (only built when ``val_json`` is not None) is neither
    shuffled nor augmented. Both loaders use 4 persistent workers; only the
    training loader drops the last incomplete batch.

    Returns:
        tuple: (train_dataloader, val_dataloader); val_dataloader is None when
        ``val_json`` is None.
    """
    # Settings shared by the training and validation datasets.
    shared_dataset_kwargs = dict(
        tokenizer=tokenizer,
        mode=data_mode,
        num_tiles=num_tiles,
        negative_captions=negative_prompt_training,
        block_embeddings=block_embeddings,
    )
    # Settings shared by both loaders.
    shared_loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=4,
        persistent_workers=True,
    )

    train_dataset = LevelDataset(
        json_path=json_path, shuffle=True, augment=augment, **shared_dataset_kwargs
    )
    val_dataset = (
        LevelDataset(json_path=val_json, shuffle=False, augment=False, **shared_dataset_kwargs)
        if val_json is not None
        else None
    )

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, drop_last=True, **shared_loader_kwargs
    )
    val_dataloader = (
        DataLoader(val_dataset, shuffle=False, drop_last=False, **shared_loader_kwargs)
        if val_dataset is not None
        else None
    )

    return train_dataloader, val_dataloader
63
+
64
+
65
def get_random_training_samples(train_dataloader, negative_prompt_training, output_dir = None):
    """Pick four random dataset items and report their captions.

    Captions are read from index 1 of each dataset item and, when
    ``negative_prompt_training`` is truthy, negative captions from index 2.
    Everything is printed, and also written to
    ``<output_dir>/sample_captions.txt`` when ``output_dir`` is given.

    Returns:
        tuple: (sample_captions, sample_negative_captions). The second element
        is a list when negative prompts are enabled and the empty string "" otherwise.
    """
    dataset = train_dataloader.dataset

    # Sample four random indices (with replacement)
    picks = []
    for _ in range(4):
        picks.append(random.randint(0, len(dataset) - 1))

    sample_captions = [dataset[pick][1] for pick in picks]
    print("Sample captions:")
    for text in sample_captions:
        print(text)

    sample_negative_captions = ""
    if negative_prompt_training:
        sample_negative_captions = [dataset[pick][2] for pick in picks]
        print("Sample negative captions:")
        for text in sample_negative_captions:
            print(f" NEG: {text}")

    if output_dir is None:
        return sample_captions, sample_negative_captions

    # Write captions to a file
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, "sample_captions.txt")
    report_lines = ["Sample captions:\n"]
    report_lines.extend(str(text) + "\n" for text in sample_captions)
    if negative_prompt_training:
        report_lines.append("\nSample negative captions:\n")
        report_lines.extend(str(text) + "\n" for text in sample_negative_captions)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("".join(report_lines))
    print(f"Sample captions written to {out_path}")

    return sample_captions, sample_negative_captions
98
+
99
+
100
def start_plotter(log_file, output_dir, left_key, right_key, left_label, right_label, png_name):
    """Create a Plotter for ``log_file`` and run it on a daemon thread.

    The PNG file name is ``<png_name>_<YYYYmmdd-HHMMSS>.png``. Only that bare
    file name is handed to the Plotter; ``output_dir`` is used solely in the
    printed message.

    Returns:
        tuple: (plotter, plot_thread)
    """
    timestamp = datetime.now().strftime(r'%Y%m%d-%H%M%S')
    png_filename = f'{png_name}_{timestamp}.png'

    plotter = Plotter(
        log_file,
        update_interval=5.0,
        left_key=left_key,
        right_key=right_key,
        left_label=left_label,
        right_label=right_label,
        output_png=png_filename,
    )

    # Daemon thread, so it does not keep the interpreter alive on exit.
    plot_thread = threading.Thread(target=plotter.start_plotting, daemon=True)
    plot_thread.start()

    print(f"{png_name} plotting enabled. Progress will be saved to {os.path.join(output_dir, png_filename)}")
    return plotter, plot_thread
110
+
111
+
112
def kill_plotter(plotter, plot_thread):
    """Ask a running plotter to stop and wait up to 5 seconds for its thread.

    Does nothing when ``plot_thread`` is falsy or no longer alive. Prints a
    warning if the thread is still alive after the join timeout.
    """
    if not plot_thread or not plot_thread.is_alive():
        return

    plotter.stop_plotting()
    plot_thread.join(timeout=5.0)

    if plot_thread.is_alive():
        print("Warning: Plot thread did not terminate properly")
118
+
119
+
120
def load_config_from_json(config_path):
    """Load hyperparameters from a JSON config file.

    The file is read as UTF-8 so the result does not depend on the platform's
    default encoding. The loaded key/value pairs are printed for verification.

    Args:
        config_path (str): Path to the JSON file.

    Returns:
        The parsed JSON content (expected to be a dict of hyperparameters).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"Error loading config file: {e}")
        raise  # re-raise with the original traceback intact

    print(f"Configuration loaded from {config_path}")

    # Print the loaded config for verification
    print("Loaded hyperparameters:")
    for key, value in config.items():
        print(f" {key}: {value}")

    return config
136
+
137
+
138
def update_args_from_config(args, config):
    """Overwrite attributes of ``args`` with same-named values from ``config``.

    Config keys that are not already attributes of ``args`` are ignored, so the
    config cannot introduce new arguments. ``args`` is modified in place and
    also returned.
    """
    known_keys = (name for name in config if hasattr(args, name))
    for name in known_keys:
        setattr(args, name, config[name])
    return args
145
+
146
+
147
def get_scene_from_embeddings(image, block_embeddings):
    """Turn an embedding-space sample into per-tile probabilities.

    Code copied over from level_dataset, should give limited support for block embeddings.

    Each spatial position's embedding vector is compared (cosine similarity)
    against every row of ``block_embeddings``; the similarities are passed
    through a softmax over the tile axis.

    Args:
        image (torch.Tensor): Sample of shape [batch_size, embedding_dim, height, width].
        block_embeddings (torch.Tensor): Tile embedding table of shape
            [num_tiles, embedding_dim]. Any number of tiles is supported.

    Returns:
        torch.Tensor: CPU tensor of shape [batch_size, num_tiles, height, width]
        whose values along dim 1 sum to 1.
    """
    # Reshape sample to [batch_size * height * width, embedding_dim]
    batch_size, embedding_dim, height, width = image.shape
    flat_samples = image.permute(0, 2, 3, 1).reshape(-1, embedding_dim)

    # Normalize vectors for cosine similarity. The samples are moved to the
    # CPU, so the embedding table is moved to the same device before matmul.
    flat_samples = F.normalize(flat_samples, p=2, dim=1).cpu()
    block_embeddings = F.normalize(block_embeddings.to(flat_samples.device), p=2, dim=1)

    # Cosine similarity between each position and all tile embeddings
    similarities = torch.matmul(flat_samples, block_embeddings.t())

    # Softmax over tiles: a probability per tile, not a hard index
    tile_probs = torch.softmax(similarities, dim=1)

    # Reshape back to [batch_size, height, width, num_tiles], then move the
    # tile axis into the channel position. The tile count comes from the
    # embedding table instead of being fixed at 13.
    num_tiles = block_embeddings.shape[0]
    tile_probs = tile_probs.reshape(batch_size, height, width, num_tiles)
    tile_probs = tile_probs.permute(0, 3, 1, 2)

    return tile_probs.detach().cpu()
171
+
172
+
latent_diffusion_pipeline.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DDPMPipeline
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Union, List, Tuple
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+ from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
7
+ import common_settings as common_settings
8
+ import os
9
+ import json
10
+ from general_training_helper import get_scene_from_embeddings
11
+
12
class UnconditionalDDPMPipeline(DDPMPipeline):
    """DDPM pipeline that samples scenes without any text conditioning.

    On top of the base DDPM pipeline it optionally carries a ``block_embeddings``
    tensor, which is saved/loaded next to the model as ``block_embeddings.pt``
    and, when present, is used to convert the final sample into per-tile
    probabilities.
    """

    def __init__(self, unet, scheduler, block_embeddings=None):
        super().__init__(unet, scheduler)

        self.block_embeddings = block_embeddings

    def save_pretrained(self, save_directory, **kwargs):
        """Save the pipeline and, if set, the block embeddings tensor.

        Extra keyword arguments are forwarded unchanged to the base class
        ``save_pretrained``.
        """
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory, **kwargs)
        # Save block_embeddings tensor if it exists
        if self.block_embeddings is not None:
            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Load the pipeline, attaching ``block_embeddings.pt`` when the file exists."""
        pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
        # Load block_embeddings tensor if it exists (always onto the CPU)
        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(block_embeds_path):
            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
        else:
            pipeline.block_embeddings = None
        return pipeline

    def give_sprite_scaling_factors(self, sprite_scaling_factors):
        """
        Set the sprite scaling factors for the pipeline.
        This is used to apply per-sprite temperature scaling during inference.
        """
        self.sprite_scaling_factors = sprite_scaling_factors

    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        height: int = common_settings.MARIO_HEIGHT, width: int = common_settings.MARIO_WIDTH,
        latents: Optional[torch.FloatTensor] = None,
        show_progress_bar=True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Run the reverse diffusion process and return tile probabilities.

        Args:
            batch_size: Number of samples to generate (ignored when ``latents`` is given).
            generator: One generator, or a list of generators, used for the
                initial noise and the scheduler steps.
            num_inference_steps: Number of scheduler timesteps.
            output_type: Accepted for interface compatibility; not used here.
            return_dict: If False, return a 1-tuple instead of ImagePipelineOutput.
            height, width: Spatial size of the generated sample.
            latents: Optional starting sample; replaces the random noise.
            show_progress_bar: Whether to wrap the timestep loop in a progress bar.

        Returns:
            ImagePipelineOutput whose ``images`` is a CPU tensor, or ``(images,)``
            when ``return_dict`` is False.
        """
        self.unet.eval()
        with torch.no_grad():

            if latents is not None:
                image = latents.to(self.device)
            else:
                image_shape = (
                    batch_size,
                    self.unet.config.in_channels,
                    height,
                    width
                )

                # randn_tensor (unlike torch.randn) accepts the list-of-generators
                # form advertised by the signature and copes with a generator
                # that lives on a different device than the pipeline.
                image = randn_tensor(image_shape, generator=generator, device=self.device)

            self.scheduler.set_timesteps(num_inference_steps)

            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
            for t in iterator:
                model_output = self.unet(image, t).sample
                image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            # Apply per-sprite temperature scaling if enabled. The attribute only
            # exists after give_sprite_scaling_factors() has been called.
            sprite_scaling_factors = getattr(self, "sprite_scaling_factors", None)
            if sprite_scaling_factors is not None:
                # Keep the factors on the same device as the sample before dividing.
                image = image / sprite_scaling_factors.to(image.device).view(1, -1, 1, 1)

            if self.block_embeddings is not None:
                image = get_scene_from_embeddings(image, self.block_embeddings)
            else:
                image = F.softmax(image, dim=1)
                image = image.detach().cpu()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def print_unet_architecture(self):
        """Prints the architecture of the UNet model."""
        print(self.unet)
model_index.json CHANGED
@@ -6,7 +6,7 @@
6
  "DDPMScheduler"
7
  ],
8
  "text_encoder": [
9
- "models.text_model",
10
  "TransformerModel"
11
  ],
12
  "tokenizer": [
 
6
  "DDPMScheduler"
7
  ],
8
  "text_encoder": [
9
+ "text_model",
10
  "TransformerModel"
11
  ],
12
  "tokenizer": [
naming_conventions.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Pairs of (stored model name, display name). The list order is significant:
# it is returned as the display order by get_model_name_map_and_order().
model_name_map = [
    # Conditional models
    ("Mar1and2-conditional-regular", "MLM-regular"),
    ("Mar1and2-conditional-absence", "MLM-absence"),
    ("Mar1and2-conditional-negative", "MLM-negative"),
    # MiniLM text encoder, single / multiple
    ("Mar1and2-conditional-MiniLM-regular", "MiniLM-single-regular"),
    ("Mar1and2-conditional-MiniLM-absence", "MiniLM-single-absence"),
    ("Mar1and2-conditional-MiniLM-negative", "MiniLM-single-negative"),
    ("Mar1and2-conditional-MiniLMsplit-regular", "MiniLM-multiple-regular"),
    ("Mar1and2-conditional-MiniLMsplit-absence", "MiniLM-multiple-absence"),
    ("Mar1and2-conditional-MiniLMsplit-negative", "MiniLM-multiple-negative"),
    # GTE text encoder, single / multiple
    ("Mar1and2-conditional-GTE-regular", "GTE-single-regular"),
    ("Mar1and2-conditional-GTE-absence", "GTE-single-absence"),
    ("Mar1and2-conditional-GTE-negative", "GTE-single-negative"),
    ("Mar1and2-conditional-GTEsplit-regular", "GTE-multiple-regular"),
    ("Mar1and2-conditional-GTEsplit-absence", "GTE-multiple-absence"),
    ("Mar1and2-conditional-GTEsplit-negative", "GTE-multiple-negative"),
    # FDM models
    ("Mar1and2-fdm-MiniLM-regular", "FDM-MiniLM-regular"),
    ("Mar1and2-fdm-MiniLM-absence", "FDM-MiniLM-absence"),
    ("Mar1and2-fdm-GTE-regular", "FDM-GTE-regular"),
    ("Mar1and2-fdm-GTE-absence", "FDM-GTE-absence"),
    # Remaining models
    ("Mar1and2-wgan", "WGAN"),
    ("Mar1and2-unconditional", "Unconditional"),
    ("MarioGPT_metrics", "MarioGPT"),
]
25
+
26
def get_model_name_map_and_order():
    """Return the name mapping as a dict plus the display names in list order.

    Returns:
        tuple: (mapping, order), where ``mapping`` maps each stored model name
        to its display name and ``order`` lists the display names in the order
        they appear in ``model_name_map``.
    """
    display_names = [display for _, display in model_name_map]
    return dict(model_name_map), display_names
pipeline_loader.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text_diffusion_pipeline import TextConditionalDDPMPipeline
2
+ from latent_diffusion_pipeline import UnconditionalDDPMPipeline
3
+ import os
4
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
5
+
6
+
7
def get_pipeline(model_path):
    """Load the diffusion pipeline stored at ``model_path``.

    ``model_path`` is either a local directory or a Hugging Face Hub model ID.
    In both cases the presence of a text encoder decides between the
    text-conditional and the unconditional pipeline.

    Args:
        model_path (str): Local model directory or Hub model ID.

    Returns:
        The loaded pipeline object.

    Raises:
        ValueError: If ``model_path`` is a local directory without a ``unet``
            subfolder, so no pipeline type can be chosen.
    """
    # Local directory: decide from the subfolders that are present.
    if os.path.isdir(model_path):
        if not os.path.exists(os.path.join(model_path, "unet")):
            raise ValueError(
                f"Local model directory '{model_path}' has no 'unet' subfolder; "
                "cannot determine which pipeline to load"
            )
        if os.path.exists(os.path.join(model_path, "text_encoder")):
            # If it has a text encoder and a unet, it's text conditional diffusion
            return TextConditionalDDPMPipeline.from_pretrained(model_path)
        # If it has no text encoder, use the unconditional diffusion model
        return UnconditionalDDPMPipeline.from_pretrained(model_path)

    # Otherwise assume it's a Hugging Face Hub model ID and inspect its config.
    # load_config returns the config dict itself (it only returns a tuple when
    # return_unused_kwargs=True), so it must not be tuple-unpacked.
    try:
        config = DiffusionPipeline.load_config(model_path)
    except Exception:
        # Best effort: if the config cannot be fetched, fall back to unconditional.
        config = {}

    # Pipeline components are top-level keys of model_index.json; the
    # "components" lookup is kept for configs that nest them.
    is_text_conditional = (
        "text_encoder" in config
        or "text_encoder" in str(config.get("components", {}))
    )

    # NOTE(review): the custom_pipeline values below still use the "models."
    # prefix -- confirm they resolve now that the pipeline files live in the root.
    if is_text_conditional:
        # Use the local pipeline file for custom_pipeline
        custom_pipeline = "models.text_diffusion_pipeline.TextConditionalDDPMPipeline"
    else:
        # Fallback: try unconditional
        custom_pipeline = "models.latent_diffusion_pipeline.UnconditionalDDPMPipeline"

    return DiffusionPipeline.from_pretrained(
        model_path,
        custom_pipeline=custom_pipeline,
        trust_remote_code=True,
    )
plotter.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Track changes in loss and learning rate during execution
2
+ import argparse
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import os
6
+ import time
7
+ import json
8
+ import tempfile
9
+ import shutil
10
+ from pathlib import Path
11
+
12
+
13
def parse_args():
    """Parse the command-line options of the plotting script.

    Returns:
        argparse.Namespace: with log_file, left_key, right_key, left_label,
        right_label, output_png, update_interval (float) and start_point.
    """
    parser = argparse.ArgumentParser(description="Train a text-conditional diffusion model for tile-based level generation")

    # Dataset args
    parser.add_argument("--log_file", type=str, default=None, help="The the filepath of the file to get the data from")
    parser.add_argument("--left_key", type=str, default=None, help="The key for the left y-axis")
    parser.add_argument("--right_key", type=str, default=None, help="The key for the right y-axis")
    parser.add_argument("--left_label", type=str, default=None, help="The label for the left y-axis")
    parser.add_argument("--right_label", type=str, default=None, help="The label for the right y-axis")
    parser.add_argument("--output_png", type=str, default="output.png", help="The output png file")
    # float, matching the 1.0 default, so fractional intervals such as 0.5 are accepted
    parser.add_argument("--update_interval", type=float, default=1.0, help="The update inteval in epochs")
    parser.add_argument("--start_point", type=int, default=None, help="The start point for the plot")

    return parser.parse_args()
27
+
28
+
29
def main():
    """Command-line entry point: parse the options and render the plot once."""
    args = parse_args()

    general_update_plot(
        args.log_file,
        args.left_key,
        args.right_key,
        args.left_label,
        args.right_label,
        args.output_png,
        update_interval=args.update_interval,
        startPoint=args.start_point,
    )
42
+
43
+
44
def general_update_plot(log_file, left_key, right_key, left_label, right_label, output_png, update_interval=1.0, startPoint=None):
    """Render the metrics stored in a JSON-lines log file to a PNG.

    Every non-blank line of ``log_file`` is parsed as a JSON object. The
    ``left_key`` values (0 when missing) are plotted against ``epoch`` for all
    entries; the ``right_key`` values are plotted only for entries that contain
    that key. Both series share one axis. Nothing is written when the log file
    is missing or has no (remaining) entries.

    Args:
        log_file: Path of the JSON-lines log.
        left_key, right_key: Keys of the two plotted metrics.
        left_label, right_label: Legend labels; ``left_label`` is also the y-axis label.
        output_png: Output file. A bare file name is placed next to the log file;
            a path containing a directory is used as given.
        update_interval: Accepted but not used by this function.
        startPoint: If given, only entries with ``epoch >= startPoint`` are plotted.
    """
    base_dir = os.path.dirname(log_file)

    # The figure is created up front; the finally clause guarantees it is closed.
    fig = plt.figure(figsize=(10, 6))
    axis = fig.add_subplot(111)

    try:
        if not os.path.exists(log_file):
            return

        with open(log_file, 'r') as handle:
            records = [json.loads(raw) for raw in handle if raw.strip()]
        if not records:
            return

        if startPoint is not None:
            records = [rec for rec in records if rec.get('epoch', 0) >= startPoint]
            if not records:
                return

        epochs = [rec.get('epoch', 0) for rec in records]
        left_values = [rec.get(left_key, 0) for rec in records]

        # For the right metric (e.g., lr), only include points where right_key exists
        with_right = [rec for rec in records if right_key in rec]
        right_epochs = [rec.get('epoch', 0) for rec in with_right]
        right_values = [rec.get(right_key) for rec in with_right]

        axis.clear()

        # Plot both metrics on the same axis
        axis.plot(epochs, left_values, 'b-', label=left_label)
        if right_epochs:
            axis.plot(right_epochs, right_values, 'r-', label=right_label)

        axis.set_xlabel('Epoch')
        axis.set_ylabel(left_label)
        axis.set_title('Training Progress')
        axis.legend(loc='upper left')
        # Limit x-axis to startPoint if provided
        if startPoint is not None:
            axis.set_xlim(left=startPoint)
        fig.tight_layout()

        # A path with a directory part is used as-is; a bare name goes next to the log.
        has_directory = os.path.isabs(output_png) or os.path.dirname(output_png)
        output_path = output_png if has_directory else os.path.join(base_dir, output_png)

        save_figure_safely(fig, output_path)
    finally:
        plt.close(fig)  # Ensure figure is closed even if an error occurs
102
+
103
def save_figure_safely(fig, output_path):
    """Save ``fig`` via a temporary file, then move it to ``output_path``.

    The figure is first written to a fresh ``.png`` temporary file. The
    destination directory is created if needed and the temporary file is moved
    there; if the move raises OSError, it is copied and the temporary file
    removed. On any failure the temporary file is cleaned up and the exception
    re-raised.
    """
    destination = str(Path(output_path))  # Convert to string path

    # Reserve a temporary .png path; the handle is closed right away so the
    # figure can be written to it by name.
    scratch = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    scratch.close()
    tmp_path = scratch.name

    try:
        fig.savefig(tmp_path)

        # Create output directory if it doesn't exist
        target_dir = os.path.dirname(os.path.abspath(destination))
        os.makedirs(target_dir, exist_ok=True)

        # Try to move the file to its final destination;
        # if the move fails, copy and then delete instead
        try:
            shutil.move(tmp_path, destination)
        except OSError:
            shutil.copy2(tmp_path, destination)
            os.unlink(tmp_path)
    except Exception:
        # Clean up temporary file if anything goes wrong
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
130
+
131
class Plotter:
    """Re-renders a training-progress plot from a log file at a fixed interval.

    ``start_plotting`` loops (typically on a background thread) until
    ``stop_plotting`` is called. The class is also a context manager: leaving
    the ``with`` block stops plotting.
    """

    def __init__(self, log_file, update_interval=1.0, left_key='loss', right_key='lr',
                 left_label='Loss', right_label='Learning Rate', output_png='training_progress.png'):
        self.log_dir = os.path.dirname(log_file)
        self.log_file = log_file
        self.update_interval = update_interval  # seconds slept between re-renders
        self.running = True
        self.output_png = output_png
        self.left_key = left_key
        self.right_key = right_key
        self.left_label = left_label
        self.right_label = right_label

        # Select the Agg backend so figures are only rendered to files.
        matplotlib.use('Agg')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_plotting()

    def __del__(self):
        self.stop_plotting()

    def update_plot(self):
        """Render the current contents of the log file to the output PNG."""
        general_update_plot(self.log_file, self.left_key, self.right_key,
                            self.left_label, self.right_label, self.output_png,
                            update_interval=self.update_interval)

    def start_plotting(self):
        """Re-render the plot every ``update_interval`` seconds until stopped."""
        print("Starting plotting in background")
        while self.running:
            self.update_plot()
            time.sleep(self.update_interval)

    def stop_plotting(self):
        """Stop the plotting loop and render one final plot.

        Idempotent: only the first call does any work, so the extra calls made
        by ``__exit__`` and ``__del__`` do not re-render or print again.
        """
        # getattr also covers __del__ running on an instance whose __init__
        # failed before `running` was assigned.
        if not getattr(self, 'running', False):
            return
        self.running = False
        self.update_plot()
        print("Plotting stopped")
171
+
172
+ if __name__ == "__main__":
173
+ main()
sampler.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import subprocess
8
+ import tempfile
9
+
10
+ import numpy as np
11
+ import torch
12
+ from PIL.Image import Image
13
+ from tqdm import tqdm
14
+ from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
15
+
16
+
17
+ from mario_gpt.lm.base import BaseMarioLM
18
+ from mario_gpt.prompter import Prompter
19
+ from mario_gpt.simulator import Simulator
20
+ from mario_gpt.utils import (
21
+ convert_level_to_png,
22
+ load_level,
23
+ save_level,
24
+ trim_level,
25
+ view_level,
26
+ )
27
+
28
def scene_to_ascii(scene, id_to_char, shorten: bool = True) -> List[str]:
    """
    Render a 2D grid of tile IDs as ASCII text, one string per row.

    Args:
        scene: List[List[int]] - 2D array of tile IDs
        id_to_char: Dict[int, str] - mapping from tile ID to ASCII character
        shorten: bool - If True and the scene has more than 15 rows, only the
                 LAST 15 rows are converted, so A* Mario (for SNES graphics)
                 can run without glitching
    Returns:
        List[str]: one string per kept row, top to bottom
    """
    # Slicing from the end keeps the bottom of the level when it is too tall
    rows = scene[-15:] if shorten and len(scene) > 15 else scene
    ascii_rows = []
    for row in rows:
        ascii_rows.append("".join(id_to_char[tile_id] for tile_id in row))
    return ascii_rows
44
+
45
@dataclass
class SampleOutput:
    """
    Bundle of everything produced for one sampled level.

    Only ``level`` is required; every other field has a default so an
    instance can also be built from a level loaded off disk (see ``load``).
    """

    # Level as a list of row strings; None when decoding the tensor failed
    level: Optional[List[str]]
    # Caption returned by the prompter for this level, when a prompter is given
    prompt: Optional[str] = None
    # Rendered image of the whole level
    img: Optional[Image] = None
    # String / image views of only the newly sampled tokens
    sample_predictions_str: Optional[List[str]] = None
    sample_predictions_img: Optional[Image] = None
    # Token tensors the string/image views were decoded from
    level_tensor: Optional[torch.Tensor] = None
    sample_predictions_tensor: Optional[torch.Tensor] = None
    # Uses MarioEval graphics for rendering levels when True
    use_snes_graphics: bool = False

    @classmethod
    def create(
        cls,
        level_tensor: torch.Tensor,
        sample_predictions_tensor: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> SampleOutput:
        """
        Build one output from a single (unbatched) level tensor.

        Decoding and rendering are best-effort: a failure is printed and the
        affected fields are left as None instead of aborting the sample.
        """
        # batch = 1
        level = None
        img = None
        try:
            level = view_level(level_tensor, tokenizer)
            img = convert_level_to_png(level)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for full level! Got error {e}"
            )
            # Reset both so a level without an image is never returned
            level = None
            img = None

        try:
            sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
            sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for sampled predictions! Got error {e}"
            )
            sample_predictions_str = None
            sample_predictions_img = None

        prompt = None
        if prompter is not None:
            prompt = prompter(level_tensor)[0]

        # Use cls (not the hard-coded class) so subclasses get their own type
        return cls(
            level=level,
            prompt=prompt,
            img=img,
            sample_predictions_str=sample_predictions_str,
            sample_predictions_img=sample_predictions_img,
            level_tensor=level_tensor,
            sample_predictions_tensor=sample_predictions_tensor,
        )

    @classmethod
    def from_level_predictions(
        cls,
        level: torch.Tensor,
        sample_predictions: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> Union[SampleOutput, List[SampleOutput]]:
        """
        Build output(s) from possibly batched tensors.

        Returns a single SampleOutput when the trimmed, squeezed level tensor
        is 1-D, otherwise a list with one SampleOutput per leading-dim entry.
        """
        level_tensor = trim_level(level).squeeze().detach().cpu()
        sample_predictions_tensor = (
            trim_level(sample_predictions).squeeze().detach().cpu()
        )

        if len(level_tensor.shape) == 1:
            return cls.create(
                level_tensor, sample_predictions_tensor, tokenizer, prompter
            )

        return [
            cls.create(_level_tensor, _sample_predictions_tensor, tokenizer, prompter)
            for _level_tensor, _sample_predictions_tensor in zip(
                level_tensor, sample_predictions_tensor
            )
        ]

    def save(self, filename: str) -> str:
        """Write the level to *filename* and return what ``save_level`` returns."""
        # The original dropped this return value despite the ``-> str`` annotation
        return save_level(self.level, filename)

    @classmethod
    def load(cls, filename: str) -> SampleOutput:
        """Create an instance holding only the level read from *filename*."""
        level = load_level(filename)
        return cls(level=level)

    def _build_simulator(self) -> CustomSimulator:
        """Wrap the level in a simulator using the jar that matches ``use_snes_graphics``."""
        jar_path = "MarioEval.jar" if self.use_snes_graphics else "NESMarioEval.jar"
        return CustomSimulator(level=self.level, jar_path=jar_path)

    def play(self, game="mario", level_idx=None, dataset_path=None):
        """
        Play the level using the specified game engine.
        game: "mario" (default) or "loderunner"
        level_idx: Lode Runner level index to play; 1 is used when None
        dataset_path: existing Lode Runner JSON file to play instead of this level
        """
        if game == "loderunner":
            import json
            # Import first so a missing package fails before any file is written
            from LodeRunner.loderunner import main

            if dataset_path is None:
                # Convert self.level (list of strings) to Lode Runner JSON format
                scene = [list(row) for row in self.level]
                lr_json = [{
                    "scene": scene,
                    "caption": ""
                }]
                with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
                    json.dump(lr_json, tmp)
                # NOTE(review): the temp file is left on disk because it is not
                # visible here whether play_lr_level blocks until the game ends.
                level_path = tmp.name
            else:
                # No temp file is needed when the caller supplies a dataset
                level_path = dataset_path
            print(f"Playing Lode Runner level interactively -- {level_path}!")
            main.play_lr_level(level_path, level_index=level_idx if level_idx is not None else 1)
        else:
            self._build_simulator().interactive()

    def run_astar(self, render=True):
        """Run the A* agent on the level and return the simulator's text output."""
        return self._build_simulator().astar(render)
173
+
174
class CustomSimulator:
    """
    The classic Mario simulator used by MarioGPT is generally
    better, but it doesn't return any information about
    Mario's performance. The main point of this simulator
    is that information about the performance of the agent
    is printed to the console (though I still need a way
    to caption and return that information)
    """

    def __init__(self, level, jar_path="MarioEval.jar"):
        """
        Args:
            level: list of row strings; only the last 15 rows are kept
            jar_path: path of the Java jar that runs the game
        """
        # For some reason, my older A* agent
        # crashes on Mario levels with 16 rows or more.
        # Slicing copies, so the caller's list is no longer truncated in place.
        self.level = list(level[-15:])
        self.jar_path = jar_path

    def _run_jar(self, agent, render_str, announcement):
        """Write the level to a temp file, run the jar on it, and always delete the file."""
        t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        # Only the unique name is needed; save_level reopens the path itself
        t.close()
        try:
            save_level(self.level, t.name)
            print(announcement.format(path=t.name))
            return subprocess.run(
                ["java", "-jar", self.jar_path, agent, t.name, render_str],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        finally:
            # Runs even when java is missing or saving fails, so nothing leaks
            os.unlink(t.name)

    def interactive(self):
        """Let a human play the level; the jar's captured output is discarded."""
        self._run_jar("human", "human", "Playing level interactively -- {path}!")

    def astar(self, render: bool = True):
        """
        Run the A* agent on the level.

        Args:
            render: show the game window ("human") or run headless ("norender")
        Returns:
            The jar's stdout followed by its stderr, as one string.
        """
        render_str = "human" if render else "norender"
        result = self._run_jar(
            "astar", render_str, "Running Astar agent on level! -- {path}"
        )
        # Combine stdout and stderr, decode to string, and return.
        # errors="replace" keeps stray non-UTF-8 bytes from raising.
        output = result.stdout.decode("utf-8", errors="replace") + result.stderr.decode(
            "utf-8", errors="replace"
        )
        return output
220
+
221
def save_level(level: List[str], filename: str):
    """Write the level rows to *filename*, newline-separated, and return *filename*.

    No trailing newline is written after the last row.
    """
    text = "\n".join(level)
    with open(filename, "w") as out_file:
        out_file.write(text)
    return filename
226
+
227
class GPTSampler:
    """Autoregressive sampler that extends a level one token at a time."""

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 700,
        use_tqdm: bool = False,
        use_argmax: bool = False,
    ):
        """
        Args:
            mario_lm: model wrapper providing lm, prompter, tokenizer and device
            temperature: value given to TemperatureLogitsWarper
            top_k: value given to TopKLogitsWarper
            context_len: token budget used to bound the model input in sample()
            use_tqdm: show a progress bar while sampling
            use_argmax: take the arg-max token instead of sampling
        """
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.context_len = context_len
        self.use_tqdm = use_tqdm
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model."""
        return self.mario_lm.device

    def step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the next token for every sequence in *seed*.

        Returns:
            (next_tokens, encoder_hidden_states); the hidden states are
            passed through unchanged.
        """
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if len(logits.shape) == 2:
                logits = logits.view(1, 1, -1)
            # Only the logits at the last position predict the next token
            next_token_logits = logits[:, -1, :]

            if self.use_argmax:
                next_tokens = next_token_logits.argmax(-1)
            else:
                next_token_scores = self.logits_processor(input_ids, next_token_logits)
                next_token_scores = self.logits_warper(input_ids, next_token_scores)
                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        encoder_hidden_states: torch.Tensor = None,
        return_tensor: bool = False,
    ):
        """
        Generate *num_steps* new tokens.

        Args:
            seed: starting tokens (tensor or a previous SampleOutput); a fresh
                seed is generated when None
            prompts: one caption per sequence; when None, a single sequence is
                assumed for seed creation and prompts are drawn from the prompter
            num_steps: number of tokens to append
            encoder_hidden_states: precomputed conditioning; derived from
                *prompts* (or sampled prompts) when None
            return_tensor: also return the raw token tensor

        Returns:
            The result of SampleOutput.from_level_predictions, or a
            (result, out_tensor) tuple when *return_tensor* is True.
        """
        self.mario_lm.eval()
        # NOTE(review): 28 presumably reserves two 14-token level columns -- confirm
        context_len = self.context_len - 28
        # prompts is optional, so never call len() on it directly
        num_prompts = len(prompts) if prompts is not None else 1
        with torch.no_grad():
            if seed is None:
                seed = self.mario_lm.generate_seed(1, batch_size=num_prompts).to(
                    self.device
                )
                out_tensor = seed.to(self.device)
            elif isinstance(seed, SampleOutput):
                out_tensor = seed.level_tensor.to(self.device).squeeze()
            else:
                out_tensor = seed.to(self.device).squeeze()
            if len(out_tensor.shape) < 2:
                # if we pass in a single seed vector, then we repeat for each prompt
                # Otherwise, we treat inputs as separate seed-prompt pairs
                out_tensor = out_tensor.view(1, -1).repeat(num_prompts, 1)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter.output_hidden(prompt)
                            for prompt in prompts
                        ]
                    )
                else:
                    # Use the batch size of out_tensor: seed may be a
                    # SampleOutput, which has no .shape
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter(sample_prompt=True)[1]
                            for _ in range(out_tensor.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(
                out_tensor.shape[0], 1, -1
            )
            if not self.use_tqdm:
                bar = np.arange(num_steps)
            else:
                bar = tqdm(np.arange(num_steps))
            with torch.no_grad():
                for i in bar:
                    inp = out_tensor * 1
                    if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
                        diff = inp.shape[-1] % 14  # height of mario level
                        # Keep the partial trailing column plus context_len tokens
                        ctx = context_len + diff
                        inp = inp[:, -ctx:] * 1
                    next_tokens, encoder_hidden_states = self.step(
                        inp,
                        encoder_hidden_states=encoder_hidden_states,
                    )
                    out_tensor = torch.cat(
                        [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
                    )
                    if self.use_tqdm:
                        bar.set_description(
                            f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
                        )
            if self.use_tqdm:
                bar.close()
        sample_out = SampleOutput.from_level_predictions(
            out_tensor,
            out_tensor[:, -num_steps:],
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        # The model is always switched back to training mode afterwards
        self.mario_lm.train()
        if return_tensor:
            return sample_out, out_tensor
        return sample_out

    def __call__(self, *args, **kwargs):
        """Alias for sample()."""
        return self.sample(*args, **kwargs)
365
+
366
+
367
class BertSampler:
    """
    Masked-token sampler: fills the masked positions of a level in one pass.

    NOTE(review): the original implementation could not run (undefined
    attribute and variables, float slice bounds). The fixes below remove
    those failures; confirm the windowing matches the intended behavior.
    """

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 448,
        mask_proportion: float = 0.16,
        use_argmax: bool = False,
    ):
        """
        Args:
            mario_lm: model wrapper providing lm, prompter, tokenizer and device
            temperature: value given to TemperatureLogitsWarper
            top_k: value given to TopKLogitsWarper
            context_len: maximum number of tokens fed to the model per sequence
            mask_proportion: fraction of context_len budgeted for masked tokens
            use_argmax: take the arg-max token instead of sampling (sample()
                read this attribute but it was never set before)
        """
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )
        self.context_len = context_len
        self.mask_proportion = mask_proportion
        self.mask_portion = int(self.context_len * self.mask_proportion)
        # Round up to the next multiple of 14
        self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14

    @property
    def device(self) -> torch.device:
        """Device of the wrapped model."""
        return self.mario_lm.device

    def _context_bounds(self, seq_len, mask_indices):
        """Return integer (left, right) slice bounds of the window around the masked span."""
        start_idx = int(mask_indices[0])
        end_idx = int(mask_indices[-1])

        if seq_len <= self.context_len:
            # NOTE(review): the original kept only the first seq_len % 14
            # tokens here; dropping the trailing partial column is assumed
            # to be what was intended -- confirm.
            seq_len -= seq_len % 14

        # Integer division: slice bounds must be ints
        portion = (self.context_len - self.mask_portion) // 2

        remainder = 0
        left = start_idx - portion
        if left < 0:
            # Give the room missing on the left to the right side instead
            remainder = -left
            left = 0

        right = min(seq_len, end_idx + portion + remainder)
        return left, right

    def get_context(self, input_ids, mask_indices):
        """
        Cut a window out of a 1-D token sequence around its masked span.

        Args:
            input_ids: 1-D token tensor
            mask_indices: masked positions in ascending order (non-empty)
        """
        left, right = self._context_bounds(input_ids.shape[-1], mask_indices)
        return input_ids[left:right]

    def sample(
        self,
        seed: Union[torch.Tensor, SampleOutput],
        mask: torch.Tensor,
        return_tensor: bool = False,
    ):
        """
        Replace the masked tokens of *seed* with model predictions.

        Args:
            seed: token tensor (1-D or batch x length) or a SampleOutput
            mask: tensor of the same layout; non-zero entries mark positions
                to re-sample. Every row needs at least one masked position,
                and all rows must yield windows of equal length to be stacked.
            return_tensor: also return the predicted token tensor

        Returns:
            The result of SampleOutput.from_level_predictions, or a
            (result, tokens) tuple when *return_tensor* is True.
        """
        self.mario_lm.eval()
        input_ids = seed.level_tensor if isinstance(seed, SampleOutput) else seed
        input_ids = input_ids.to(self.device)
        mask = mask.to(self.device)
        if len(input_ids.shape) < 2:
            # A single sequence is treated as a batch of one
            input_ids = input_ids.view(1, -1)
        if len(mask.shape) < 2:
            mask = mask.view(1, -1)
        mask_indices = mask.nonzero()

        input_id_list = []
        local_mask_list = []
        for i in range(input_ids.shape[0]):
            mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
            left, right = self._context_bounds(input_ids.shape[-1], mask_index)
            input_id_list.append(input_ids[i, left:right])
            # Shift mask positions into window coordinates; drop any outside it
            local_index = mask_index - left
            in_window = (local_index >= 0) & (local_index < right - left)
            local_mask_list.append(local_index[in_window])
        input_ids = torch.stack(input_id_list, dim=0).to(self.device)

        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=None,
            )
        logits = out.logits.detach()
        if len(logits.shape) == 2:
            # NOTE(review): assumes 2-D logits are (length, vocab) for one sequence
            logits = logits.unsqueeze(0)

        if self.use_argmax:
            tokens = logits.argmax(-1)
        else:
            batch_size, seq_len, vocab_size = logits.shape
            # torch.multinomial only accepts 1-D/2-D input: flatten positions
            flat_logits = logits.reshape(batch_size * seq_len, vocab_size)
            tokens_scores = self.logits_processor(input_ids, flat_logits)
            tokens_scores = self.logits_warper(input_ids, tokens_scores)
            probs = torch.nn.functional.softmax(tokens_scores, dim=-1)
            tokens = torch.multinomial(probs, num_samples=1).view(batch_size, seq_len)

        out = input_ids.detach().clone()

        for i, local_index in enumerate(local_mask_list):
            out[i, local_index] = tokens[i, local_index].detach()

        sample_out = SampleOutput.from_level_predictions(
            out,
            tokens,
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        # The model is always switched back to training mode afterwards
        self.mario_lm.train()
        if return_tensor:
            return sample_out, tokens
        return sample_out
sentence_transformers_helper.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ #Mean Pooling - Take average of all tokens
6
def mean_pooling(model_output, attention_mask):
    """Average ``last_hidden_state`` over dim 1, weighting each position by the mask.

    The mask is broadcast over the embedding dimension; the divisor is
    clamped to at least 1e-9 so an all-zero mask row does not divide by zero.
    """
    hidden = model_output.last_hidden_state
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    weighted_sum = (hidden * weights).sum(dim=1)
    weight_total = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / weight_total
10
+
11
+
12
+ #Encode text
13
def encode(texts, tokenizer, model, device='cpu'):
    """Turn a list of strings into L2-normalized, mean-pooled embeddings.

    The tokenized batch is moved to *device* before the forward pass; the
    model itself is not moved here. Returns one embedding row per text.
    """
    # Tokenize sentences (padded to a common length, truncated if too long)
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    batch.to(device)

    # Compute token embeddings without tracking gradients
    with torch.no_grad():
        output = model(**batch, return_dict=True)

    # Pool over tokens, then scale every row to unit length
    pooled = mean_pooling(output, batch['attention_mask'])
    normalized = F.normalize(pooled, p=2, dim=1)
    return normalized.to(device)
31
+
32
+ # Get embeddings for a batch of captions and optional negative captions
33
def get_embeddings(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu'):
    """Stack embeddings in the order: negative captions, empty strings, captions.

    *batch_size* empty strings are always encoded; the caption and negative
    caption groups are added only when given. A singleton dimension is
    inserted at index 1 of the result.
    """
    groups = [encode([""] * batch_size, tokenizer, model, device)]

    if captions is not None:
        groups.append(encode(captions, tokenizer, model, device))

    if neg_captions is not None:
        # Negative embeddings go in front of everything else
        groups.insert(0, encode(neg_captions, tokenizer, model, device))

    return torch.cat(groups, dim=0).unsqueeze(1)
48
+
49
+
50
+
51
+
52
def get_embeddings_split(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu', max_length=20):
    """Embed captions sentence by sentence instead of as whole strings.

    Each caption is split on periods and padded with empty strings to a
    common sentence count, so every caption becomes a sequence of sentence
    embeddings. Groups are stacked in the order: negative captions, empty
    captions, captions.

    Args:
        batch_size: number of all-empty captions to encode
        captions / neg_captions: optional lists of caption strings
        device: device passed through to the encoder helpers
        max_length: largest sentence count allowed for any caption

    Raises:
        ValueError: if any caption has more than *max_length* sentences.
    """

    def longest(caption_list):
        # Count sentences exactly as split_sentences does. Counting "."
        # characters under-counts when the last sentence has no period,
        # which produced sequences longer than the padding length.
        return max(
            (sum(1 for part in caption.split(".") if part.strip()) for caption in caption_list),
            default=0,
        )

    padding_length = max(
        longest(captions) if captions else 1,
        longest(neg_captions) if neg_captions else 1,
        1,  # never 0, or there would be nothing to encode
    )
    if padding_length > max_length:
        raise ValueError(f"Token sequence length {padding_length} exceeds specified length {max_length}.")

    empty_split = split_sentences([""] * batch_size, padding_length)
    embeddings = get_embeddings_from_split(empty_split, tokenizer, model, device)

    if captions is not None:
        captions_split = split_sentences(captions, padding_length)
        caption_embeddings = get_embeddings_from_split(captions_split, tokenizer, model, device)
        embeddings = torch.cat((embeddings, caption_embeddings), dim=0)

    if neg_captions is not None:
        neg_split = split_sentences(neg_captions, padding_length)
        neg_embeddings = get_embeddings_from_split(neg_split, tokenizer, model, device)
        embeddings = torch.cat((neg_embeddings, embeddings), dim=0)

    # We don't need to unsqueeze this, we have an array of (batch_size, padding_length, encoding_size) already
    return embeddings.to(device)
77
+
78
+
79
+ #This method takes a caption batch in list form, and outputs a 2d list where every caption has been split by period
80
def split_sentences(caption_array, padding_length=20):
    """Split each caption on periods and pad the pieces with empty strings.

    Pieces are stripped and empty ones dropped. A caption with more than
    *padding_length* sentences is returned at its full length (not
    truncated), so callers must choose *padding_length* large enough.
    """

    def split_and_pad(caption):
        stripped = (piece.strip() for piece in caption.split("."))
        sentences = [piece for piece in stripped if piece]
        # This is the token padding, we just use an empty string
        return sentences + [""] * (padding_length - len(sentences))

    return [split_and_pad(caption) for caption in caption_array]
91
+
92
+
93
+ #Expects all split vectors to be the same length
94
def get_embeddings_from_split(caption_batch, tokenizer, model, device='cpu'):
    """Encode every caption's sentence list and stack the results on a new dim 0.

    Each sentence list is encoded as if it were a batch; all lists must have
    the same length for the stack to succeed.
    """
    # We don't reshape the per-caption result to avoid having to unsqueeze it later
    per_caption = [
        encode(sentence_list, tokenizer, model, device)
        for sentence_list in caption_batch
    ]
    return torch.stack(per_caption, dim=0)
105
+
106
+
107
+
108
if __name__ == "__main__":
    # Manual smoke test: split two sample captions and embed them
    cap = split_sentences(["Hello. My name is George. How. Are you doing. Today?", "I am doing. Just fine. Thanks."])
    model_url = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
    # Fall back to the CPU so the check also runs on machines without CUDA
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained(model_url)
    model = AutoModel.from_pretrained(model_url, trust_remote_code=True).to(device)
    get_embeddings_from_split(cap, tokenizer, model, device)
text_diffusion_pipeline.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import NamedTuple, Optional
4
+ import os
5
+ from diffusers import DDPMPipeline, UNet2DConditionModel, DDPMScheduler
6
+ import json
7
+ # Running the main at the end of this requires messing with this import
8
+ from text_model import TransformerModel
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import AutoTokenizer, AutoModel
12
+ import common_settings as common_settings
13
+ import sentence_transformers_helper as st_helper
14
+ import text_model as text_model
15
+ from general_training_helper import get_scene_from_embeddings
16
+
17
class PipelineOutput(NamedTuple):
    """Return type of the text-conditional pipeline's ``__call__``."""
    # Generated batch as a single tensor
    images: torch.Tensor
19
+
20
+
21
+
22
+ # Create a custom pipeline for text-conditional generation
23
+ class TextConditionalDDPMPipeline(DDPMPipeline):
24
def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
    """Set up the pipeline and register its modules.

    Args:
        unet: denoising model; its optional ``negative_prompt_support``
            attribute decides ``supports_negative_prompt``
        scheduler: diffusion scheduler
        text_encoder: optional caption encoder
        tokenizer: optional tokenizer; taken from ``text_encoder.tokenizer``
            when omitted and a text encoder is given
        supports_pretrained_split: stored as-is on the pipeline
        block_embeddings: stored as-is on the pipeline
    """
    super().__init__(unet=unet, scheduler=scheduler)
    self.text_encoder = text_encoder
    self.tokenizer = tokenizer
    # Missing attribute counts as "not supported"
    self.supports_negative_prompt = getattr(unet, 'negative_prompt_support', False)
    self.supports_pretrained_split = supports_pretrained_split
    self.block_embeddings = block_embeddings

    # Use the tokenizer from the text encoder if not provided
    if self.text_encoder is not None and self.tokenizer is None:
        self.tokenizer = self.text_encoder.tokenizer

    # Register the text_encoder so that .to(), .cpu(), .cuda(), etc. work correctly
    modules = dict(
        unet=unet,
        scheduler=scheduler,
        text_encoder=self.text_encoder,
        tokenizer=self.tokenizer,
    )
    self.register_modules(**modules)
43
+
44
+ # Override the to() method to ensure text_encoder is moved to the correct device
45
def to(self, device=None, dtype=None):
    """Move the pipeline, and explicitly the text encoder, to *device*.

    *dtype* is forwarded only to the parent's ``to()``; the text encoder is
    moved by device alone. Returns whatever the parent's ``to()`` returns.
    """
    # Call the parent's to() method first
    moved = super().to(device, dtype)

    # Additionally move the text_encoder to the device
    encoder = self.text_encoder
    if encoder is not None:
        encoder.to(device)

    return moved
54
+
55
def save_pretrained(self, save_directory):
    """Save the pipeline so ``from_pretrained`` can rebuild it.

    Written into *save_directory*:
      * whatever the parent pipeline saves,
      * ``block_embeddings.pt`` when block embeddings are set,
      * ``pipeline_config.json`` with the feature flags and encoder type,
      * ``text_encoder/``: the full model for a custom TransformerModel,
        otherwise only ``loading_info.json`` naming the pretrained encoder
        and tokenizer. Nothing is written there when there is no text encoder.
    """
    os.makedirs(save_directory, exist_ok=True)
    super().save_pretrained(save_directory)  # saves UNet and scheduler

    # Save block_embeddings tensor if it exists
    if self.block_embeddings is not None:
        torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))

    # Save supports_negative_prompt and supports_pretrained_split flags
    with open(os.path.join(save_directory, "pipeline_config.json"), "w") as f:
        json.dump({
            "supports_negative_prompt": self.supports_negative_prompt,
            "supports_pretrained_split": self.supports_pretrained_split,
            "text_encoder_type": type(self.text_encoder).__name__
        }, f)

    # An unconditional pipeline has no encoder to save. The original fell
    # into the pretrained branch here and failed on None.config.
    if self.text_encoder is None:
        return

    text_encoder_directory = os.path.join(save_directory, "text_encoder")

    # Text encoder/tokenizer saving is different depending on if we're using a larger pretrained model
    if isinstance(self.text_encoder, TransformerModel):
        # Save custom text encoder
        self.text_encoder.save_pretrained(text_encoder_directory)
        return

    # Save pretrained tokenizer by name, so we can load from huggingface instead of saving a giant local model
    text_encoder_info = {
        "text_encoder_name": self.text_encoder.config.name_or_path,
        "tokenizer_name": self.tokenizer.name_or_path,
    }
    os.makedirs(text_encoder_directory, exist_ok=True)
    with open(os.path.join(text_encoder_directory, "loading_info.json"), "w") as f:
        json.dump(text_encoder_info, f)
89
+
90
+
91
+
92
@classmethod
def from_pretrained(cls, pretrained_model_path, **kwargs):
    """Rebuild a pipeline from a directory written by ``save_pretrained``.

    Args:
        pretrained_model_path: directory containing ``unet/`` and
            ``scheduler/`` and optionally ``text_encoder/``,
            ``block_embeddings.pt`` and ``pipeline_config.json``
        **kwargs: forwarded to the pipeline constructor

    Returns:
        The pipeline, with no text encoder/tokenizer when ``text_encoder/``
        is absent.
    """
    # Load components manually
    unet = UNet2DConditionModel.from_pretrained(os.path.join(pretrained_model_path, "unet"))

    # Have heard that DDIMScheduler might be faster for inference, though not necessarily better
    scheduler = DDPMScheduler.from_pretrained(os.path.join(pretrained_model_path, "scheduler"))

    text_encoder = None
    tokenizer = None
    text_encoder_path = os.path.join(pretrained_model_path, "text_encoder")

    if os.path.exists(text_encoder_path):
        loading_info_path = os.path.join(text_encoder_path, "loading_info.json")
        # Test for the new saving system, where we save a simple config file
        if os.path.exists(loading_info_path):
            with open(loading_info_path, "r") as f:
                encoder_config = json.load(f)

            # NOTE(security): trust_remote_code=True runs code shipped with the
            # named model; only load directories from trusted sources.
            text_encoder = AutoModel.from_pretrained(encoder_config['text_encoder_name'], trust_remote_code=True)
            tokenizer = AutoTokenizer.from_pretrained(encoder_config['tokenizer_name'])

        # Legacy loading system, loads models directly if the whole thing is saved in the directory
        else:
            try:
                text_encoder = AutoModel.from_pretrained(text_encoder_path, local_files_only=True, trust_remote_code=True)
                tokenizer = AutoTokenizer.from_pretrained(text_encoder_path, local_files_only=True)
            except (ValueError, KeyError):
                # Not a model the Auto classes accept: fall back to the custom encoder
                text_encoder = TransformerModel.from_pretrained(text_encoder_path)
                tokenizer = text_encoder.tokenizer

    # Instantiate your pipeline
    pipeline = cls(
        unet=unet,
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        **kwargs,
    )

    # Loads block embeddings if present. When the file is missing the value
    # set by the constructor (None unless passed in kwargs) is kept; the
    # original overwrote a caller-supplied value with None.
    block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
    if os.path.exists(block_embeds_path):
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")

    # Load supports_negative_prompt flag if present
    config_path = os.path.join(pretrained_model_path, "pipeline_config.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            config = json.load(f)
        pipeline.supports_negative_prompt = config.get("supports_negative_prompt", False)
        pipeline.supports_pretrained_split = config.get("supports_pretrained_split", False)
    return pipeline
154
+
155
+ # --- Handle batching for captions ---
156
+ def _prepare_text_batch(self, text: Optional[str | list[str]], batch_size: int, name: str) -> Optional[list[str]]:
157
+ if text is None:
158
+ return None
159
+ if isinstance(text, str):
160
+ return [text] * batch_size
161
+ if isinstance(text, list):
162
+ if len(text) == 1:
163
+ return text * batch_size
164
+ if len(text) != batch_size:
165
+ raise ValueError(f"{name} list length {len(text)} does not match batch_size {batch_size}")
166
+ return text
167
+ raise ValueError(f"{name} must be a string or list of strings")
168
+
169
+ def _prepare_initial_sample(self,
170
+ raw_latent_sample: Optional[torch.Tensor],
171
+ input_scene: Optional[torch.Tensor],
172
+ batch_size: int, height: int, width: int,
173
+ generator: Optional[torch.Generator]) -> torch.Tensor:
174
+ """Prepare the initial sample for diffusion."""
175
+
176
+ sample_shape = (batch_size, self.unet.config.in_channels, height, width)
177
+
178
+ if raw_latent_sample is not None:
179
+ if input_scene is not None:
180
+ raise ValueError("Cannot provide both raw_latent_sample and input_scene")
181
+ sample = raw_latent_sample.to(self.device)
182
+ if sample.shape[1] != sample_shape[1]:
183
+ raise ValueError(f"Wrong number of channels in raw_latent_sample: Expected {self.unet.config.in_channels} but got {sample.shape[1]}")
184
+ if sample.shape[0] == 1 and batch_size > 1:
185
+ sample = sample.repeat(batch_size, 1, 1, 1)
186
+ elif sample.shape[0] != batch_size:
187
+ raise ValueError(f"raw_latent_sample batch size {sample.shape[0]} does not match batch_size {batch_size}")
188
+ elif input_scene is not None:
189
+ # input_scene can be (H, W) or (batch_size, H, W)
190
+ scene_tensor = torch.tensor(input_scene, dtype=torch.long, device=self.device)
191
+ if scene_tensor.dim() == 2:
192
+ # (H, W) -> repeat for batch
193
+ scene_tensor = scene_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
194
+ elif scene_tensor.shape[0] == 1 and batch_size > 1:
195
+ scene_tensor = scene_tensor.repeat(batch_size, 1, 1)
196
+ elif scene_tensor.shape[0] != batch_size:
197
+ raise ValueError(f"input_scene batch size {scene_tensor.shape[0]} does not match batch_size {batch_size}")
198
+ # One-hot encode: (batch, H, W, C)
199
+ one_hot = F.one_hot(scene_tensor, num_classes=self.unet.config.in_channels).float()
200
+ # (batch, H, W, C) -> (batch, C, H, W)
201
+ sample = one_hot.permute(0, 3, 1, 2)
202
+ else:
203
+ # Start from random noise
204
+ sample = torch.randn(sample_shape, generator=generator, device=self.device)
205
+
206
+ return sample
207
+
208
def __call__(
    self,
    caption: Optional[str | list[str]] = None,
    negative_prompt: Optional[str | list[str]] = None,
    generator: Optional[torch.Generator] = None,
    num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
    guidance_scale: float = common_settings.GUIDANCE_SCALE,
    height: int = common_settings.MARIO_HEIGHT,
    width: int = common_settings.MARIO_WIDTH,
    raw_latent_sample: Optional[torch.FloatTensor] = None,
    input_scene: Optional[torch.Tensor] = None,
    output_type: str = "tensor",
    batch_size: int = 1,
    show_progress_bar: bool = True,
) -> PipelineOutput:
    """Generate a batch of images based on text input using the diffusion model.

    Args:
        caption: Text description(s) of the desired output. Can be a string or list of strings.
        negative_prompt: Text description(s) of what should not appear in the output. String or list.
        generator: Random number generator for reproducibility.
        num_inference_steps: Number of denoising steps (more = higher quality, slower).
        guidance_scale: How strongly the generation follows the text prompt (higher = stronger).
        height: Height of generated image in tiles.
        width: Width of generated image in tiles.
        raw_latent_sample: Optional starting point for diffusion instead of random noise.
            Must have correct number of channels matching the UNet.
        input_scene: Optional 2D or 3D int tensor where each value corresponds to a tile type.
            Will be converted to one-hot encoding as starting point.
        output_type: Currently only "tensor" is supported.
        batch_size: Number of samples to generate in parallel.
        show_progress_bar: When True, the scheduler timesteps are wrapped in
            ``self.progress_bar`` so denoising progress is displayed.

    Returns:
        PipelineOutput containing the generated image tensor (batch_size, ...).

    Raises:
        ValueError: If a caption is given but the pipeline has no text encoder,
            or if ``output_type`` is not supported.
    """
    # TODO: split this method into helpers (prepare embeddings / run diffusion /
    # format output). An earlier all-at-once refactor did not work, so it should
    # be tackled bit by bit.

    # Validate text encoder if we need it
    if caption is not None and self.text_encoder is None:
        raise ValueError("Text encoder is required for conditional generation")

    # Fail fast: reject an unsupported output type before running the (expensive)
    # denoising loop instead of after it.
    if output_type != "tensor":
        raise ValueError(f"Unsupported output type: {output_type}")

    self.unet.eval()
    if self.text_encoder is not None:
        self.text_encoder.to(self.device)
        self.text_encoder.eval()

    with torch.no_grad():
        captions = self._prepare_text_batch(caption, batch_size, "caption")
        negatives = self._prepare_text_batch(negative_prompt, batch_size, "negative_prompt")

        # --- Prepare text embeddings ---
        if isinstance(self.text_encoder, TransformerModel):
            text_embeddings = text_model.get_embeddings(
                batch_size=batch_size,
                tokenizer=self.text_encoder.tokenizer,
                text_encoder=self.text_encoder,
                captions=captions,
                neg_captions=negatives,
                device=self.device,
            )
        else:
            # Case for the pre-trained text encoder; the "split" variant is used
            # when the pipeline carries the split flag.
            if self.supports_pretrained_split:
                embed_fn = st_helper.get_embeddings_split
            else:
                embed_fn = st_helper.get_embeddings
            text_embeddings = embed_fn(
                batch_size=batch_size,
                tokenizer=self.tokenizer,
                model=self.text_encoder,
                captions=captions,
                neg_captions=negatives,
                device=self.device,
            )

        # --- Set up initial latent state ---
        sample = self._prepare_initial_sample(raw_latent_sample, input_scene,
                                              batch_size, height, width, generator)

        # --- Set up diffusion process ---
        self.scheduler.set_timesteps(num_inference_steps)

        # These do not change between steps, so compute them once.
        use_guidance = captions is not None
        use_negative = use_guidance and negatives is not None
        # 3 copies: negative / unconditional / text; 2 copies: unconditional / text
        num_copies = 3 if use_negative else (2 if use_guidance else 1)
        model_kwargs = {"encoder_hidden_states": text_embeddings}

        # Denoising loop
        timesteps = self.scheduler.timesteps
        iterator = self.progress_bar(timesteps) if show_progress_bar else timesteps
        for t in iterator:
            # Replicate the sample so one UNet pass yields every prediction
            # needed for guidance.
            if num_copies > 1:
                model_input = torch.cat([sample] * num_copies, dim=0)
            else:
                model_input = sample

            # Predict noise residual
            noise_pred = self.unet(model_input, t, **model_kwargs).sample

            # Apply guidance
            if use_negative:
                # Split predictions for negative, unconditional, and text-conditional
                noise_pred_neg, noise_pred_uncond, noise_pred_text = noise_pred.chunk(3)
                noise_pred_guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = noise_pred_guided - guidance_scale * (noise_pred_neg - noise_pred_uncond)
            elif use_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute previous sample: x_{t-1} = scheduler(x_t, noise_pred)
            sample = self.scheduler.step(noise_pred, t, sample, generator=generator).prev_sample

        # Convert to output format (only "tensor" reaches this point)
        if self.block_embeddings is not None:
            sample = get_scene_from_embeddings(sample, self.block_embeddings)
        else:
            # Apply softmax to get probabilities for each tile type
            sample = F.softmax(sample, dim=1)
        # NOTE(review): applied to both branches here; confirm against the original
        # indentation that the block-embedding result should also be moved to CPU.
        sample = sample.detach().cpu()

    return PipelineOutput(images=sample)
370
+
371
def print_unet_architecture(self):
    """Write the UNet's string representation to standard output."""
    unet = self.unet
    print(unet)
374
+
375
def print_text_encoder_architecture(self):
    """Print the text encoder's architecture, or a notice when none is attached."""
    encoder = self.text_encoder
    if encoder is None:
        print("No text encoder is set.")
        return
    print(encoder)
381
+
382
def save_unet_architecture_pdf(self, height, width, filename="unet_architecture", batch_size=1, device=None):
    """
    Have to separately install torchview for this to work

    Saves a visualization of the UNet architecture as a PDF using torchview.
    The graph is traced with random dummy inputs; graphviz renders
    ``<filename>.pdf`` and the DOT source is additionally saved as
    ``<filename>.dot``.

    Args:
        height: Height of the dummy input.
        width: Width of the dummy input.
        filename: Output file name WITHOUT extension (graphviz appends ``.pdf``).
        batch_size: Batch size for dummy input.
        device: Device to run the dummy input on (defaults to pipeline device).
    """
    # Imported lazily because torchview is an optional dependency.
    from torchview import draw_graph

    if device is None:
        device = getattr(self, 'device', 'cpu')
    in_channels = self.unet.config.in_channels if hasattr(self.unet, 'config') else 1
    sample_shape = (batch_size, in_channels, height, width)

    dummy_x = torch.randn(size=sample_shape, device=device)
    dummy_t = torch.tensor([0] * batch_size, dtype=torch.long, device=device)

    # Prepare dummy text embedding (match what your UNet expects); 128 is the
    # fallback when the UNet config does not expose cross_attention_dim.
    unet_config = getattr(self.unet, 'config', None)
    cross_attention_dim = getattr(unet_config, 'cross_attention_dim', 128)
    encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim, device=device)

    self.unet.eval()
    inputs = (dummy_x, dummy_t, encoder_hidden_states)

    graph = draw_graph(
        model=self.unet,
        input_data=inputs,
        expand_nested=False,
        depth=1
    )
    graph.visual_graph.attr(
        nodesep="0.1",       # decrease space between nodes in the same rank (default ~0.25)
        ranksep="0.2",       # decrease space between ranks (default ~0.5)
        concentrate="true"   # merge edges between nodes in the same rank
    )
    graph.visual_graph.node_attr.update(
        shape="rectangle",
        width="1.5",
        height="0.5"
    )

    # cleanup=False keeps the intermediate source file next to the PDF.
    graph.visual_graph.render(filename, format='pdf', cleanup=False)
    # Derive the DOT file name from `filename` so a custom name is honored
    # (the default still yields 'unet_architecture.dot').
    graph.visual_graph.save(f"{filename}.dot")

    # render() appends the format extension, so report the real output path.
    print(f"UNet architecture saved to {filename}.pdf")
442
+
text_model.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from xml.parsers.expat import model
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+ import os
7
+ import json
8
+ from safetensors.torch import save_file, load_file
9
+ from tokenizer import Tokenizer
10
+
11
def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
    """Build the stacked text embeddings consumed by guided diffusion.

    The result is concatenated along dim 0 in the order
    ``[negative (optional), unconditional, caption (optional)]``, where the
    unconditional part is the embedding of ``batch_size`` empty captions.
    """
    max_length = text_encoder.max_seq_length

    def _embed(texts):
        # Tokenize + pad the texts, then run them through the encoder.
        ids = encode_token_captions(texts, tokenizer, max_length, device=device)
        return text_encoder.get_embeddings(ids)

    # Unconditional embeddings are always present.
    embeddings = _embed([""] * batch_size)

    if captions is not None:
        embeddings = torch.cat((embeddings, _embed(captions)), dim=0)

    if neg_captions is not None:
        # Negative embeddings go in front of everything else.
        embeddings = torch.cat((_embed(neg_captions), embeddings), dim=0)

    return embeddings.to(device)
27
+
28
def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
    """Tokenize and pad each caption, returning one long tensor with a row per caption.

    Each caption is encoded with ``tokenizer.encode`` and padded to
    ``max_length`` with ``tokenizer.pad_sequence``; the rows are then stacked
    and moved to ``device``.
    """
    rows = [
        torch.tensor(tokenizer.pad_sequence(tokenizer.encode(text), max_length), dtype=torch.long)
        for text in captions
    ]
    return torch.stack(rows, dim=0).to(device)
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+ # Transformer model for MLM training
45
+
46
class TransformerModel(nn.Module):
    """Transformer encoder over caption tokens, used for MLM training.

    ``get_embeddings`` returns the per-token latent vectors and ``forward``
    projects them to vocabulary logits. An optional tokenizer travels with the
    model and is saved/loaded alongside the weights.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
        super().__init__()
        # Hyperparameters are kept so save_pretrained() can write them to config.json.
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_seq_length = max_seq_length

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Plain tensor attribute (not a parameter/buffer), so it is absent from
        # state_dict and is moved to the input's device on each call.
        self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.tokenizer = tokenizer

    def create_positional_encoding(self, max_seq_length, embedding_dim):
        """Return a sinusoidal positional encoding of shape (1, max_seq_length, embedding_dim)."""
        # The frequencies create unique values per position, the sin/cos bounds them.
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # Creates a set of divisors that create different frequencies
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        angles = position * div_term
        pe = torch.zeros(max_seq_length, embedding_dim)
        # Even dimensions use sin, odd dimensions use cos
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one fewer odd column than even
        # column, so trim the cosine terms to fit instead of crashing.
        pe[:, 1::2] = torch.cos(angles[:, :embedding_dim // 2])
        return pe.unsqueeze(0)

    def get_embeddings(self, x):
        """ This gets the actual latent embedding vectors """
        # Ensure positional encoding is on the same device as input
        pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
        # Embed input and add positional encoding
        embedded = self.embedding(x) + pe
        return self.transformer(embedded)

    def forward(self, x):
        """ This gets the token within the vocabulary """
        transformer_out = self.get_embeddings(x)
        # Project to vocabulary size
        return self.fc(transformer_out)

    def save_pretrained(self, save_directory):
        """Write config.json, model.safetensors and (if set) tokenizer.pkl to save_directory."""
        os.makedirs(save_directory, exist_ok=True)

        config = {
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "max_seq_length": self.max_seq_length,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

        # Save model weights
        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))

        # Save tokenizer if present
        if self.tokenizer is not None:
            self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model from a directory written by save_pretrained()."""
        with open(os.path.join(load_directory, "config.json")) as f:
            config = json.load(f)

        model = cls(**config)

        # Load weights
        state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
        model.load_state_dict(state_dict)

        # Load tokenizer if available
        tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer()
            tokenizer.load(tokenizer_path)
            model.tokenizer = tokenizer

        return model

    def print_architecture(self, inputs=None):
        """Render the architecture of the model named by ``--model_path`` to 'mlm_architecture.pdf'.

        NOTE(review): this method parses the process command line and draws the
        model loaded from ``--model_path`` rather than ``self``; that behavior
        is preserved here. Requires torchview.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_path", type=str, required=True, help="Path to trained transformer model")
        parser.add_argument("--json", type=str, default="SMB1_LevelsAndCaptions-regular-test.json", help="Path to dataset json file")
        parser.add_argument("--num_samples", type=int, default=10, help="Number of captions to evaluate")
        parser.add_argument("--mask_prob", type=float, default=0.15, help="Probability of masking each token")

        parser.add_argument("--compare_checkpoints", action="store_true", default=False, help="Run comparison across all model checkpoints")
        args = parser.parse_args()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = TransformerModel.from_pretrained(args.model_path).to(device)
        print(f"Loaded model from {args.model_path}")

        # Only torchview is actually needed here; it is optional, so import lazily.
        from torchview import draw_graph

        graph = draw_graph(
            model=model,
            input_data=inputs,
            expand_nested=False,
            depth=1
        )

        # Save plot
        filename = 'mlm_architecture'
        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup=False keeps intermediate files

    def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
        """Save a visualization of the model architecture as a PDF using torchview.

        Args:
            filename: Target PDF path; the intermediate PNG uses the same name
                with '.pdf' replaced by '.png'.
            input_length: Dummy sequence length used when no tokenizer is
                attached. With a tokenizer, the length of the longest sample
                caption is used instead.

        Raises:
            ImportError: If torchview is not installed.
        """
        try:
            from torchview import draw_graph
        except ImportError as err:
            raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.") from err

        device = next(self.parameters()).device

        if self.tokenizer is not None:
            # Size the dummy input like realistic captions.
            captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
            input_length = max(len(self.tokenizer.encode(c)) for c in captions)
        # The positional encoding only covers max_seq_length positions.
        input_length = min(input_length, self.max_seq_length)
        dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=device)

        # Draw the graph and save as PNG
        graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf',''), directory=".", depth=2)
        png_file = filename.replace('.pdf', '.png')
        # Convert PNG to PDF
        if os.path.exists(png_file):
            try:
                from PIL import Image
                im = Image.open(png_file)
                im.save(filename, "PDF", resolution=100.0)
                print(f"Saved architecture PDF to {filename}")
                # The PNG was only an intermediate artifact.
                os.remove(png_file)
            except ImportError:
                print(f"PIL not installed. Architecture saved as PNG: {png_file}")
            except Exception as e:
                # Best effort: report the failure, keep the PNG.
                print(f"Could not convert PNG to PDF: {e}")
        else:
            print(f"Could not find PNG file to convert: {png_file}")
text_to_level_diffusion.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from interactive_generation import InteractiveGeneration
2
+ import torch
3
+ from level_dataset import visualize_samples, convert_to_level_format, positive_negative_caption_split
4
+ from caption_match import compare_captions, process_scene_segments
5
+ from create_ascii_captions import assign_caption
6
+ from util import extract_tileset
7
+ from sampler import scene_to_ascii
8
+ import argparse
9
+ import common_settings as common_settings
10
+ from sampler import SampleOutput
11
+ from pipeline_loader import get_pipeline
12
+
13
+
14
def parse_args():
    """Parse the command-line options for interactive level generation.

    Returns:
        argparse.Namespace with ``model_path``, ``tileset``, ``describe_absence``,
        ``automatic_negative_captions`` and ``game``.
    """
    import os  # local import: only needed to build the default tileset path

    parser = argparse.ArgumentParser(description="Generate levels using a trained diffusion model")
    # Model and generation parameters
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained diffusion model")
    # Built with os.path.join: the previous literal relied on invalid escape
    # sequences ('\T', '\S', '\s') and only worked with Windows separators.
    # On Windows the resulting string is unchanged.
    parser.add_argument(
        "--tileset",
        default=os.path.join("..", "TheVGLC", "Super Mario Bros", "smb.json"),
        help="Descriptions of individual tile types",
    )
    parser.add_argument("--describe_absence", action="store_true", default=False, help="Indicate when there are no occurrences of an item or structure")
    parser.add_argument("--automatic_negative_captions", action="store_true", default=False, help="Automatically create negative captions for prompts so the user doesn't have to")

    parser.add_argument(
        "--game",
        type=str,
        default="Mario",
        choices=["Mario", "LR"],
        help="Which game to create a model for (affects sample style and tile count)"
    )

    return parser.parse_args()
33
+
34
class InteractiveLevelGeneration(InteractiveGeneration):
    """Interactive prompt loop that turns text captions into levels with a diffusion pipeline.

    After each generation the produced scene is captioned, compared with the
    prompt, and the user is offered the chance to play the level.
    """

    def __init__(self, args):
        """Load the pipeline at ``args.model_path`` and configure the prompt parameters.

        Raises:
            ValueError: If automatic negative captions are requested but the
                pipeline does not support negative prompts.
        """
        # Derive the default width from the selected game instead of reading a
        # module-level `width` that only exists when this file runs as a script.
        if args.game == "LR":
            default_width = common_settings.LR_WIDTH
        else:
            default_width = common_settings.MARIO_WIDTH

        super().__init__(
            {
                "caption": str,
                "width": int,
                "negative_prompt": str,
                "start_seed": int,
                "end_seed": int,
                "num_inference_steps": int,
                "guidance_scale": float
            },
            default_parameters={
                "width": default_width,
                "start_seed": 1,
                "end_seed": 1,  # Will be set to start_seed if blank
                "num_inference_steps": common_settings.NUM_INFERENCE_STEPS,
                "guidance_scale": common_settings.GUIDANCE_SCALE,
                "caption": "",
                "negative_prompt": ""
            }
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe = get_pipeline(args.model_path).to(self.device)
        self.pipe.print_unet_architecture()

        # The user is not asked for a negative prompt when it is generated
        # automatically or when the model cannot use one.
        if args.automatic_negative_captions or not self.pipe.supports_negative_prompt:
            self.input_parameters.pop('negative_prompt', None)
            self.default_parameters.pop('negative_prompt', None)

        if args.automatic_negative_captions and not self.pipe.supports_negative_prompt:
            raise ValueError("Automatic negative caption generation is not possible with a model that doesn't support it")

        if args.tileset:
            _, self.id_to_char, self.char_to_id, self.tile_descriptors = extract_tileset(args.tileset)

        self.args = args

        # Lode Runner levels have a fixed width, so the user is not asked for one.
        if self.args.game == "LR":
            del self.input_parameters["width"]

        print(f"Tileset in use: {self.args.tileset}")

    def generate_image(self, param_values, generator, **extra_params):
        """Generate one level, report how well it matches the prompt, and return its visualization.

        Raises:
            ValueError: For an unknown game or an answer other than 'y'/'n' to
                the play prompt.
        """
        # Derive the negative prompt only when a caption is actually present
        # (get_extra_params removes empty captions).
        if self.args.automatic_negative_captions and "caption" in param_values:
            _, neg = positive_negative_caption_split(param_values["caption"], True)
            param_values["negative_prompt"] = neg
        prompt = param_values.get("caption", "")

        images = self.pipe(
            generator=generator,
            **param_values
        ).images

        # Convert to indices
        sample_tensor = images[0].unsqueeze(0)
        sample_indices = convert_to_level_format(sample_tensor)

        # Add level data to the list
        scene = sample_indices[0].tolist()
        if self.args.game == "LR":
            number_of_tiles = common_settings.LR_TILE_COUNT
            scene = [[x % number_of_tiles for x in row] for row in scene]

        # Arguments shared by both games' segment-scoring helpers.
        segment_kwargs = dict(
            scene=scene,
            prompt=prompt,
            id_to_char=self.id_to_char,
            char_to_id=self.char_to_id,
            tile_descriptors=self.tile_descriptors,
            describe_locations=False,
            describe_absence=self.args.describe_absence,
            verbose=True
        )

        # Assign a caption to the scene of whichever game is being played,
        # then score it against the prompt.
        if self.args.game == "Mario":
            actual_caption = assign_caption(scene, self.id_to_char, self.char_to_id, self.tile_descriptors, False, self.args.describe_absence)
            compare_score = compare_captions(prompt, actual_caption)
            print(f"Comparison score: {compare_score}")
            process_scene_segments(segment_width=common_settings.MARIO_WIDTH, **segment_kwargs)
        elif self.args.game == "LR":
            # NOTE(review): lr_assign_caption, lr_compare_captions and
            # lr_process_scene_segments are not imported in the visible part of
            # this file; unless they are provided elsewhere this branch raises
            # NameError. Confirm which module they belong to.
            actual_caption = lr_assign_caption(scene, self.id_to_char, self.char_to_id, self.tile_descriptors, False, self.args.describe_absence)
            print(f"Describe resulting image: {actual_caption}")
            lr_compare_score = lr_compare_captions(prompt, actual_caption)
            print(f"Comparison score: {lr_compare_score}")
            lr_process_scene_segments(segment_width=common_settings.LR_WIDTH, **segment_kwargs)
        else:
            raise ValueError(f"Unknown game: {self.args.game}")

        # Ask if user wants to play level
        play_level = input("Do you want to play this level? (y/n): ").strip().lower()
        if play_level == 'y':
            print("Playing level...")
            char_grid = scene_to_ascii(scene, self.id_to_char, False)
            level = SampleOutput(level=char_grid, use_snes_graphics=False)
            console_output = level.run_astar()
            print(console_output)
        elif play_level == 'n':
            print("Level not played.")
        else:
            raise ValueError(f"Unknown input: {play_level}")

        return visualize_samples(images)

    def get_extra_params(self, param_values):
        """Normalize the collected parameters in place before generation.

        Empty prompts are removed so the pipeline treats them as absent, the
        output type is fixed to "tensor", and Lode Runner dimensions are forced.

        Returns:
            An empty dict (no extra parameters).
        """
        if "negative_prompt" in param_values and param_values["negative_prompt"] == "":
            del param_values["negative_prompt"]

        if param_values.get("caption") == "":
            del param_values["caption"]

        param_values["output_type"] = "tensor"

        # Lode Runner
        if self.args.game == "LR":
            param_values["height"] = common_settings.LR_HEIGHT
            param_values["width"] = common_settings.LR_WIDTH

        return dict()
173
+
174
if __name__ == "__main__":
    import os  # only needed here to build portable tileset paths

    args = parse_args()

    # Tileset paths are built with os.path.join: the previous literals relied on
    # invalid escape sequences ('\T', '\S', '\L') and Windows-only separators.
    # On Windows the resulting strings are unchanged.
    vglc_root = os.path.join("..", "TheVGLC")

    # Per-game settings. `height` and `width` stay module-level names because
    # other code in this file may read them.
    if args.game == "Mario":
        args.num_tiles = common_settings.MARIO_TILE_COUNT
        height = common_settings.MARIO_HEIGHT
        width = common_settings.MARIO_WIDTH
        args.tile_size = common_settings.MARIO_TILE_PIXEL_DIM
        args.tileset = os.path.join(vglc_root, "Super Mario Bros", "smb.json")
    elif args.game == "LR":
        args.num_tiles = common_settings.LR_TILE_COUNT
        height = common_settings.LR_HEIGHT
        width = common_settings.LR_WIDTH
        args.tile_size = common_settings.LR_TILE_PIXEL_DIM
        args.tileset = os.path.join(vglc_root, "Lode Runner", "Loderunner.json")
    else:
        raise ValueError(f"Unknown game: {args.game}")

    ig = InteractiveLevelGeneration(args)
    ig.start()
194
+
tokenizer.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from collections import Counter
4
+ import pickle
5
+ import argparse
6
+
7
class Tokenizer:
    """Word-level tokenizer for level captions with [PAD] and [MASK] special tokens."""

    # Words/numbers, periods and commas, plus the (lower-cased) special tokens.
    # Compiled once instead of on every tokenize() call.
    _TOKEN_PATTERN = re.compile(r'\w+|[.,]|\[mask\]|\[pad\]')
    # tokenize() lower-cases the text, so special tokens are restored to caps.
    _SPECIAL_RESTORE = {"[mask]": "[MASK]", "[pad]": "[PAD]"}

    def __init__(self):
        self.special_tokens = ["[PAD]", "[MASK]"]
        self.vocab = {}
        self.token_to_id = {}
        self.id_to_token = {}

    def tokenize(self, text):
        """Split text into lower-case tokens; special tokens come back as [MASK]/[PAD]."""
        tokens = self._TOKEN_PATTERN.findall(text.lower())
        return [self._SPECIAL_RESTORE.get(tok, tok) for tok in tokens]

    def pad_sequence(self, tokens, length):
        """Pads tokenized sequences to length with a padding token (assumed to be '[PAD]').

        Raises:
            ValueError: If the sequence is already longer than ``length``.
        """
        if len(tokens) > length:
            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")

        pad_token = self.token_to_id["[PAD]"]
        return tokens + [pad_token] * (length - len(tokens))

    def build_vocab(self, dataset_path, min_freq=1):
        """Build the vocabulary from the 'caption' field of every entry in a JSON dataset.

        Special tokens always get the first ids; remaining tokens that occur at
        least ``min_freq`` times follow in sorted order.
        """
        token_counter = Counter()

        with open(dataset_path, 'r') as f:
            data = json.load(f)
            for entry in data:
                token_counter.update(self.tokenize(entry['caption']))

        # Keep tokens that meet the min frequency. Special tokens appearing in
        # the data are skipped here: listing them twice would leave gaps in the
        # id space and move them away from their reserved ids.
        tokens = [
            tok for tok, count in token_counter.items()
            if count >= min_freq and tok not in self.special_tokens
        ]

        # Ensure special tokens are always included
        all_tokens = self.special_tokens + sorted(tokens)

        # Build vocab dictionaries
        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
        self.token_to_id = self.vocab
        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

        print(f"Vocabulary size: {len(self.vocab)}")

    def encode(self, text):
        """Convert text to a list of token ids.

        Raises:
            ValueError: If the text contains a token that is not in the vocabulary.
        """
        encoded = []
        for tok in self.tokenize(text):
            if tok not in self.token_to_id:
                raise ValueError(f"Unknown token encountered: {tok} in {text}")
            encoded.append(self.token_to_id[tok])
        return encoded

    def encode_batch(self, texts, pad_to_length=None):
        """
        Encode a batch of texts into token IDs with padding to ensure uniform length.

        Args:
            texts (list): A list of strings to encode
            pad_to_length (int, optional): Length to pad all sequences to. If None,
                will pad to the length of the longest sequence.

        Returns:
            list: A list of lists, where each inner list contains the token IDs for a text.
                An empty ``texts`` yields an empty list.

        Raises:
            ValueError: If any text contains an unknown token.
        """
        # Get the padding token ID
        pad_token = self.token_to_id["[PAD]"]

        # First encode all texts
        encoded_texts = []
        for text in texts:
            try:
                encoded_texts.append(self.encode(text))
            except ValueError as e:
                raise ValueError(f"Error encoding text: {text}. {str(e)}") from e

        # Determine padding length; default=0 keeps an empty batch from crashing.
        if pad_to_length is None:
            pad_to_length = max((len(seq) for seq in encoded_texts), default=0)

        # Truncate sequences that are too long, pad those that are too short.
        return [
            seq[:pad_to_length] if len(seq) > pad_to_length
            else seq + [pad_token] * (pad_to_length - len(seq))
            for seq in encoded_texts
        ]

    def decode(self, token_ids):
        """Join the tokens for the given ids with single spaces."""
        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)

    def save(self, path):
        """Pickle the vocabulary to path."""
        with open(path, 'wb') as f:
            pickle.dump({'vocab': self.vocab}, f)

    def load(self, path):
        """Load a vocabulary previously written by save()."""
        # SECURITY: pickle.load can execute arbitrary code; only load tokenizer
        # files from trusted sources.
        with open(path, 'rb') as f:
            data = pickle.load(f)
            self.vocab = data['vocab']
            self.token_to_id = self.vocab
            self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}

    def get_vocab(self):
        """Return all vocabulary tokens in sorted order."""
        return sorted(self.vocab.keys())

    def get_vocab_size(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.vocab)
124
+
125
if __name__ == "__main__":
    # Command-line entry point: build and pickle a vocabulary ('save') or read
    # an existing pickle back in ('load').
    cli = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
    cli.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
    cli.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
    cli.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")
    args = cli.parse_args()

    tokenizer = Tokenizer()

    if args.action == "load":
        tokenizer.load(args.pkl_file)
    elif args.action == "save":
        # An explicitly empty --json_file leaves nothing to build from.
        if not args.json_file:
            raise ValueError("The --json_file argument is required for the 'save' action.")
        tokenizer.build_vocab(args.json_file)
        tokenizer.save(args.pkl_file)

    # Handy for manual inspection after either action:
    #   tokenizer.encode("floor with one gap. one enemy.")
    #   tokenizer.get_vocab()
    #   tokenizer.id_to_token
util.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ import os
4
+ from collections import Counter
5
+
6
+ # This file contains utility functions for analyzing and describing levels in both Lode Runner and Super Mario Bros.
7
+
8
+ # Could define these via the command line, but for now they are hardcoded
9
# Module-wide caption switches (hardcoded for now).
# NOTE(review): coarse_locations and give_staircase_lengths are not read by
# the functions visible in this part of the file; presumably used elsewhere.
coarse_locations = True
# When True, counts are worded via describe_quantity() instead of raw numbers.
coarse_counts = True
# When True, nouns take their plural form where the count calls for it.
pluralize = True
give_staircase_lengths = False
13
+
14
def describe_size(count):
    """Return "small" for counts up to and including 4, otherwise "big"."""
    return "small" if count <= 4 else "big"
17
+
18
def describe_quantity(count):
    """Map a numeric count onto a coarse quantity word.

    0 -> "no", 1 -> "one", 2 -> "two", anything else below 5 -> "a few",
    below 10 -> "several", and everything from 10 up -> "many".
    """
    exact_words = {0: "no", 1: "one", 2: "two"}
    word = exact_words.get(count)
    if word is not None:
        return word
    if count < 5:
        return "a few"
    if count < 10:
        return "several"
    return "many"
25
+
26
def get_tile_descriptors(tileset):
    """Build a mapping from each tile character to the set of its descriptors.

    Reads ``tileset["tiles"]`` (character -> iterable of descriptor strings)
    and additionally registers the placeholder tiles "!" and "*" as
    passable, overriding any entries the tileset has for those characters.
    """
    descriptors = {}
    for char, attrs in tileset["tiles"].items():
        descriptors[char] = set(attrs)

    # Fake tiles. Code elsewhere expects every tile to be passable or solid,
    # so they are given the "passable" descriptor.
    for placeholder in ("!", "*"):
        descriptors[placeholder] = {"passable"}
    return descriptors
33
+
34
def _count_runs(flags):
    """Return how many maximal runs of consecutive truthy values ``flags`` contains."""
    runs = 0
    inside = False
    for flag in flags:
        if flag and not inside:
            runs += 1
        inside = bool(flag)
    return runs


def analyze_floor(scene, id_to_char, tile_descriptors, describe_absence):
    """Describe the floor formed by the bottom row of ``scene``.

    A tile counts as floor when its descriptors include "solid" or
    "diggable", and as open space when they include "passable" or "enemy"
    (an enemy on the bottom row falls straight into the gap).

    Args:
        scene: Grid of tile ids (list of rows). The width is taken from the
            first row and the floor from the last row.
        id_to_char: Mapping from tile id to tile character.
        tile_descriptors: Mapping from tile character to its descriptors.
        describe_absence: When True, an all-passable bottom row yields
            "no floor" instead of an empty string.

    Returns:
        "full floor", "no floor" / "", "floor with <n> gap(s)" when floor
        tiles outnumber passable ones, or otherwise
        "giant gap with <n> chunk(s) of floor".

    Raises:
        KeyError: If a bottom-row tile id is missing from ``id_to_char``.
        ValueError: If a bottom-row tile is neither floor nor open space.
    """
    width = len(scene[0])
    last_row = scene[-1]  # The FLOOR row of the scene

    # Resolve every tile's descriptors once instead of on each membership test.
    row_descriptors = [tile_descriptors.get(id_to_char[tile], ()) for tile in last_row]
    solid_flags = ["solid" in tags or "diggable" in tags for tags in row_descriptors]
    passable_flags = ["passable" in tags for tags in row_descriptors]
    solid_count = sum(solid_flags)
    passable_count = sum(passable_flags)

    if solid_count == width:
        return "full floor"
    if passable_count == width:
        return "no floor" if describe_absence else ""

    open_flags = [
        passable or "enemy" in tags
        for passable, tags in zip(passable_flags, row_descriptors)
    ]
    # A partial floor must consist only of floor tiles and open tiles.
    for tile, is_solid, is_open in zip(last_row, solid_flags, open_flags):
        if not (is_solid or is_open):
            char = id_to_char[tile]
            raise ValueError(
                "Every tile should be passable, solid, or enemy "
                f"(tile id {tile!r}, char {char!r}, "
                f"descriptors {tile_descriptors.get(char, [])!r})"
            )

    if solid_count > passable_count:
        # Mostly floor: report the contiguous groups of open tiles.
        gaps = _count_runs(open_flags)
        quantity = describe_quantity(gaps) if coarse_counts else gaps
        return f"floor with {quantity} gap" + ("s" if pluralize and gaps != 1 else "")

    # Mostly open: report the contiguous groups of floor tiles.
    chunks = _count_runs(solid_flags)
    quantity = describe_quantity(chunks) if coarse_counts else chunks
    return f"giant gap with {quantity} chunk" + ("s" if pluralize and chunks != 1 else "") + " of floor"
94
+
95
def count_in_scene(scene, tiles, exclude=frozenset()):
    """Count the cells of ``scene`` whose tile is in ``tiles``.

    Args:
        scene: Grid of tile ids (iterable of rows).
        tiles: Container of tile ids to count.
        exclude: Container of ``(row, col)`` positions to leave out of the
            count. Defaults to an immutable empty set, so the default can
            never be mutated and shared between calls.

    Returns:
        The number of matching cells that are not excluded.
    """
    return sum(
        1
        for r, row in enumerate(scene)
        for c, t in enumerate(row)
        if (r, c) not in exclude and t in tiles
    )
106
+
107
def count_caption_phrase(scene, tiles, name, names, offset=0, describe_absence=False, exclude=frozenset()):
    """Build the caption phrase reporting how many ``tiles`` the scene holds.

    Args:
        scene: Grid of tile ids.
        tiles: Container of tile ids to count.
        name: Singular noun used in the phrase.
        names: Plural noun used in the phrase.
        offset: Value added to the counted number before wording it.
        describe_absence: When True, a non-positive count yields
            " no <names>." instead of an empty string.
        exclude: Container of ``(row, col)`` positions to skip. Defaults to
            an immutable empty set, so the default can never be mutated and
            shared between calls.

    Returns:
        " <quantity> <noun>." for a positive count, otherwise " no <names>."
        or "" depending on ``describe_absence``.
    """
    count = offset + count_in_scene(scene, tiles, exclude)
    if count > 0:
        quantity = describe_quantity(count) if coarse_counts else count
        noun = names if pluralize and count > 1 else name
        return f" {quantity} {noun}."
    if describe_absence:
        return f" no {names}."
    return ""
117
+
118
def in_column(scene, x, tile):
    """Return True if any row of ``scene`` holds ``tile`` at column ``x``."""
    return any(row[x] == tile for row in scene)
124
+
125
def analyze_ceiling(scene, id_to_char, tile_descriptors, describe_absence, ceiling_row=1):
    """Describe the ceiling formed by row ``ceiling_row`` (0-based) of ``scene``.

    Returns " full ceiling." when every tile in the row is solid, and
    " ceiling with <n> gap(s)." when more than half of them are. Otherwise
    returns " no ceiling." if ``describe_absence`` is set, else "".
    """
    width = len(scene[0])
    ceiling_tags = [tile_descriptors.get(id_to_char[tile], []) for tile in scene[ceiling_row]]

    solid_total = 0
    for tags in ceiling_tags:
        if "solid" in tags:
            solid_total += 1

    if solid_total == width:
        return " full ceiling."

    if solid_total <= width // 2:
        # Not enough solid tiles for a ceiling
        return " no ceiling." if describe_absence else ""

    # Count contiguous groups of open tiles. Enemies also leave a gap, but
    # they are tagged "moving" rather than "passable".
    # NOTE(review): analyze_floor tests the "enemy" descriptor for the same
    # purpose; confirm that "moving" is the intended tag here.
    gap_total = 0
    previous_open = False
    for tags in ceiling_tags:
        is_open = "passable" in tags or "moving" in tags
        if is_open and not previous_open:
            gap_total += 1
        previous_open = is_open

    quantity = describe_quantity(gap_total) if coarse_counts else gap_total
    plural_suffix = "s" if pluralize and gap_total != 1 else ""
    return f" ceiling with {quantity} gap{plural_suffix}."
162
+
163
def extract_tileset(tileset_path):
    """Load a tileset JSON file and derive its lookup tables.

    Returns a 4-tuple ``(tile_chars, id_to_char, char_to_id,
    tile_descriptors)``: the sorted tile characters, the id -> character
    and character -> id mappings (ids are positions in the sorted list),
    and the descriptor mapping produced by ``get_tile_descriptors``.
    """
    with open(tileset_path, "r") as handle:
        tileset = json.load(handle)

    # Ids follow the sorted order of the tileset's characters. The extra
    # '!' and '*' tiles that get_tile_descriptors registers are deliberately
    # not given ids here.
    tile_chars = sorted(tileset['tiles'])
    id_to_char = dict(enumerate(tile_chars))
    char_to_id = {}
    for idx, char in id_to_char.items():
        char_to_id[char] = idx

    tile_descriptors = get_tile_descriptors(tileset)
    return tile_chars, id_to_char, char_to_id, tile_descriptors
182
+
183
def flood_fill(scene, visited, start_row, start_col, id_to_char, tile_descriptors, excluded, pipes=False, target_descriptor=None):
    """Collect the 4-connected group of matching tiles around a start cell.

    A tile matches when it carries ``target_descriptor``, if one is given.
    Otherwise it must be "solid" and its "pipe" descriptor must agree with
    ``pipes`` (pipe tiles only when ``pipes`` is truthy, non-pipe tiles
    only when it is falsy). Cells already in ``visited`` or in ``excluded``
    are never entered.

    Every accepted cell is added to ``visited`` (mutated in place) and
    appended to the returned list of ``(row, col)`` tuples in visiting order.
    """
    pending = [(start_row, start_col)]
    component = []
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))

    while pending:
        cell = pending.pop()
        if cell in visited or cell in excluded:
            continue

        r, c = cell
        char = id_to_char[scene[r][c]]
        tags = tile_descriptors.get(char, [])

        if target_descriptor is None:
            # Default solid/pipe logic: solid, and pipe-ness must match `pipes`.
            wanted = "solid" in tags and ("pipe" in tags) == bool(pipes)
        else:
            wanted = target_descriptor in tags
        if not wanted:
            continue

        visited.add(cell)
        component.append(cell)

        # Special case for adjacent pipes: never step outward across a
        # pipe's left or right edge, so side-by-side pipes stay separate.
        blocks_right = char in ('>', ']')
        blocks_left = char in ('<', '[')
        for d_r, d_c in offsets:
            if (d_c == 1 and blocks_right) or (d_c == -1 and blocks_left):
                continue
            n_r, n_c = r + d_r, c + d_c
            if 0 <= n_r < len(scene) and 0 <= n_c < len(scene[0]):
                pending.append((n_r, n_c))

    return component