schrum2 committed on
Commit
1d027f9
·
verified ·
1 Parent(s): 8b0893c

Code that is apparently needed by the pipeline

Browse files

Probably more than is needed, but we'll see.

Some files had to be modified from their original repo versions

captions/caption_match.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from create_ascii_captions import assign_caption
2
+
3
# Quantity words ordered from fewest to most; the distance between two
# positions in this list drives the partial-match score.
QUANTITY_TERMS = ["one", "two", "a few", "several", "many"]

# Topics to compare. More specific keywords are listed before the shorter
# keywords they contain (e.g. "broken pipe" before "pipe").
TOPIC_KEYWORDS = [
    # "giant gap" is left out: all gaps are believed to be subsumed by the floor topic
    "floor",
    "ceiling",
    "broken pipe",
    "upside down pipe",
    "pipe",
    "coin line",
    "coin",
    "platform",
    "tower",  # "wall" is currently disabled
    "broken cannon",
    "cannon",
    "ascending staircase",
    "descending staircase",
    "rectangular",
    "irregular",
    "question block",
    "loose block",
    "enem",  # catches both "enemy" and "enemies"
]

# Kept as a list (not a dict) because the order of the entries matters.
KEYWORD_TO_NEGATED_PLURAL = [
    (" broken pipe.", ""),    # If not the first phrase
    ("broken pipe. ", ""),    # If the first phrase (after removing all others)
    (" broken cannon.", ""),  # If not the first phrase
    ("broken cannon. ", ""),  # If the first phrase (after removing all others)
    ("pipe", "pipes"),
    ("cannon", "cannons"),
    ("platform", "platforms"),
    ("tower", "towers"),
    ("staircase", "staircases"),
    ("enem", "enemies"),
    ("rectangular", "rectangular block clusters"),
    ("irregular", "irregular block clusters"),
    ("coin line", "coin lines"),
    ("coin.", "coins."),      # The period avoids matching "coin line"
    ("question block", "question blocks"),
    ("loose block", "loose blocks"),
]

# Number of topics that are considered "broken" (e.g., "broken pipe", "broken cannon")
BROKEN_TOPICS = 2

# Irregular plurals mapped to their singular form
PLURAL_EXCEPTIONS = {
    "enemies": "enemy",
}
47
+
48
def normalize_plural(phrase):
    """Return *phrase* with plural words reduced to a singular form.

    Irregular plurals listed in PLURAL_EXCEPTIONS are substituted first.
    Afterwards every whitespace-separated word that ends in a single "s"
    loses that "s"; words ending in "ss" (e.g. "class", "boss") are kept.
    The words are re-joined with single spaces.
    """
    for irregular, base_form in PLURAL_EXCEPTIONS.items():
        phrase = phrase.replace(irregular, base_form)

    return ' '.join(
        token[:-1] if token.endswith('s') and not token.endswith('ss') else token
        for token in phrase.split()
    )
63
+
64
def extract_phrases(caption, debug=False):
    """Map every topic in TOPIC_KEYWORDS to the caption phrase describing it.

    The caption is split on periods. Topics are visited in TOPIC_KEYWORDS
    order, and each phrase is assigned to at most one topic, so a phrase
    claimed by a longer topic is not reused by a shorter topic it contains.

    Args:
        caption (str): Caption whose phrases are separated by periods.
        debug (bool): If True, print how each topic was resolved.

    Returns:
        dict: topic -> matching phrase, or None when the topic is missing
        or only mentioned as an absence ("no ...").
    """
    phrases = [part.strip() for part in caption.split('.') if part.strip()]
    topic_to_phrase = {}
    already_matched_phrases = set()  # Phrases already claimed by an earlier topic

    for topic in TOPIC_KEYWORDS:
        # First phrase mentioning the topic that no earlier topic has claimed
        phrase = next(
            (p for p in phrases if topic in p and p not in already_matched_phrases),
            None,
        )

        if phrase is None:
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': no phrase found")
            continue

        # Claim the phrase whether or not it is negated. Otherwise a phrase
        # such as "no broken pipes" is matched again by the shorter "pipe"
        # topic and hides a real "two pipes" phrase in the same caption.
        already_matched_phrases.add(phrase)

        if phrase.lower().startswith("no "):
            # "no ..." phrases are equivalent to the topic being absent
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
        else:
            topic_to_phrase[topic] = phrase
            if debug:
                print(f"[Extract] Topic '{topic}': found phrase '{phrase}'")

    return topic_to_phrase
95
+
96
def quantity_score(phrase1, phrase2, debug=False):
    """Score how close the quantity words of two phrases are.

    The first QUANTITY_TERMS entry found as a substring of each phrase is
    taken as that phrase's quantity. When both phrases have one, the score
    is 1.0 minus the normalized distance between their positions in
    QUANTITY_TERMS. When either is missing, the score is 0.1.
    """
    def find_quantity(phrase):
        # First quantity term (in list order) that occurs in the phrase
        return next((term for term in QUANTITY_TERMS if term in phrase), None)

    qty1 = find_quantity(phrase1)
    qty2 = find_quantity(phrase2)

    if debug:
        print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")

    if not (qty1 and qty2):
        if debug:
            print("[Quantity] At least one quantity missing, assigning partial score 0.1")
        return 0.1

    idx1 = QUANTITY_TERMS.index(qty1)
    idx2 = QUANTITY_TERMS.index(qty2)
    diff = abs(idx1 - idx2)
    max_diff = len(QUANTITY_TERMS) - 1
    score = 1.0 - (diff / max_diff)
    if debug:
        print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
    return score
121
+
122
def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
    """Score how well a generated caption agrees with the correct caption.

    Each topic in TOPIC_KEYWORDS contributes:
      *  1.0 when both captions omit it, or their phrases are equal after
         plural normalization;
      * -1.0 when only one caption mentions it;
      *  the quantity_score when both phrases contain a quantity term;
      *  0.1 otherwise (same topic, nothing else comparable).
    The sum is divided by the number of topics.

    Args:
        correct_caption (str): Reference caption.
        generated_caption (str): Caption to evaluate.
        debug (bool): If True, print a trace of the comparison.
        return_matches (bool): If True, also return the matched phrases.

    Returns:
        float, or (float, exact_matches, partial_matches, excess_phrases)
        when return_matches is True. The lists hold phrases from the
        generated caption.
    """
    reference_by_topic = extract_phrases(correct_caption, debug=debug)
    candidate_by_topic = extract_phrases(generated_caption, debug=debug)

    num_topics = len(TOPIC_KEYWORDS)
    total_score = 0.0
    exact_matches, partial_matches, excess_phrases = [], [], []

    if debug:
        print("\n--- Starting Topic Comparison ---\n")

    for topic in TOPIC_KEYWORDS:
        correct = reference_by_topic[topic]
        generated = candidate_by_topic[topic]

        if debug:
            print(f"[Topic: {topic}] Correct: {correct} | Generated: {generated}")

        if correct is None and generated is None:
            topic_score = 1.0
            if debug:
                print(f"[Topic: {topic}] Both None — full score: 1.0\n")
        elif correct is None or generated is None:
            topic_score = -1.0
            if generated is not None:
                # Mentioned only in the generated caption: an excess phrase
                excess_phrases.append(generated)
            if debug:
                print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
        else:
            # Compare singular forms so "one pipe" and "one pipes" agree
            norm_correct = normalize_plural(correct)
            norm_generated = normalize_plural(generated)

            if debug:
                print(f"[Topic: {topic}] Normalized: Correct: '{norm_correct}' | Generated: '{norm_generated}'")

            if norm_correct == norm_generated:
                topic_score = 1.0
                exact_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Exact match — score: 1.0\n")
            elif (any(term in norm_correct for term in QUANTITY_TERMS)
                  and any(term in norm_generated for term in QUANTITY_TERMS)):
                topic_score = quantity_score(norm_correct, norm_generated, debug=debug)
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Quantity-based partial score: {topic_score:.2f}\n")
            else:
                topic_score = 0.1
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")

        total_score += topic_score
        if debug:
            print(f"[Topic: {topic}] Current total score: {total_score:.4f}\n")

    if debug:
        print("total_score before normalization:", total_score)
        print(f"Number of topics: {num_topics}")

    final_score = total_score / num_topics
    if debug:
        print(f"--- Final score: {final_score:.4f} ---\n")

    if not return_matches:
        return final_score
    return final_score, exact_matches, partial_matches, excess_phrases
193
+
194
def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
    """
    Process a scene by partitioning it into segments, assigning captions, and computing comparison scores.

    Args:
        scene (list): The scene to process, represented as a 2D list.
        segment_width (int): The width of each segment.
        prompt (str): The prompt to compare captions against.
        id_to_char (dict): Mapping from tile IDs to characters.
        char_to_id (dict): Mapping from characters to tile IDs.
        tile_descriptors (dict): Descriptions of individual tile types.
        describe_locations (bool): Whether to include location descriptions in captions.
        describe_absence (bool): Whether to indicate absence of items in captions.
        verbose (bool): If True, print captions and scores for each segment.

    Returns:
        tuple: A tuple containing the average comparison score, captions for each segment, and scores for each segment.
        A scene with no rows or no columns yields (0, [], []).
    """
    # An empty scene has no first row to measure; treat it as zero columns
    # wide so it takes the same "no segments" path as a scene of empty rows.
    scene_width = len(scene[0]) if len(scene) > 0 else 0

    # Cut every row at the same column boundaries; the last segment may be narrower
    segments = [
        [row[start:start + segment_width] for row in scene]
        for start in range(0, scene_width, segment_width)
    ]

    # Caption each segment and score it against the prompt
    segment_scores = []
    segment_captions = []
    for idx, segment in enumerate(segments):
        segment_caption = assign_caption(segment, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence)
        segment_score = compare_captions(prompt, segment_caption)
        segment_scores.append(segment_score)
        segment_captions.append(segment_caption)

        if verbose:
            print(f"Segment {idx + 1} caption: {segment_caption}")
            print(f"Segment {idx + 1} comparison score: {segment_score}")

    # Average over the segments; 0 when there were none
    average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0

    if verbose:
        print(f"Average comparison score across all segments: {average_score}")

    return average_score, segment_captions, segment_scores
238
+
239
if __name__ == '__main__':
    # Manual check: compare two floor descriptions that differ while the
    # remaining phrases are identical, printing the full debug trace.
    ref = "floor with one gap. two enemies. one platform. one tower."
    gen = "giant gap with one chunk of floor. two enemies. one platform. one tower."

    score = compare_captions(ref, gen, debug=True)
    for line in (f"Should be: {ref}", f" but was: {gen}", f"Score: {score}"):
        print(line)
captions/evaluate_caption_order_tolerance.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import os
4
+ import random
5
+ from collections import defaultdict
6
+ import sys, os
7
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+ import util.common_settings as common_settings # adjust import if needed
9
+ from level_dataset import LevelDataset, visualize_samples, colors, mario_tiles # adjust import if needed
10
+ from torch.utils.data import DataLoader
11
+ from evaluate_caption_adherence import calculate_caption_score_and_samples # adjust import if needed
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib
14
+ import json
15
+ from tqdm import tqdm
16
+
17
+ import numpy as np
18
+ import torch
19
+ from tqdm import tqdm
20
+
21
+ from captions.util import extract_tileset
22
+ from models.pipeline_loader import get_pipeline
23
+
24
+
25
def parse_args():
    """Parse the command-line arguments of the caption order tolerance evaluation.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Evaluate caption order tolerance for a diffusion model.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained diffusion model")
    parser.add_argument("--caption", type=str, required=False, default=None, help="Caption to evaluate, phrases separated by periods")
    parser.add_argument("--tileset", type=str, help="Path to the tileset JSON file")
    # Build the default with os.path.join so it resolves on POSIX as well as
    # on Windows (a literal backslash is part of the file name on POSIX).
    # Other dataset files that were used as defaults before:
    #   Test_for_caption_order_tolerance.json
    #   SMB1_LevelsAndCaptions-regular-test.json
    default_json = os.path.join("datasets", "Mar1and2_LevelsAndCaptions-regular.json")
    parser.add_argument("--json", type=str, default=default_json, help="Path to dataset json file")
    #parser.add_argument("--trials", type=int, default=3, help="Number of times to evaluate each caption permutation")
    parser.add_argument("--inference_steps", type=int, default=common_settings.NUM_INFERENCE_STEPS)
    parser.add_argument("--guidance_scale", type=float, default=common_settings.GUIDANCE_SCALE)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--game", type=str, choices=["Mario", "LR"], default="Mario", help="Game to evaluate (Mario or Lode Runner)")
    parser.add_argument("--describe_absence", action="store_true", default=False, help="Indicate when there are no occurrences of an item or structure")
    parser.add_argument("--save_as_json", action="store_true", help="Save generated levels as JSON")
    parser.add_argument("--output_dir", type=str, default="visualizations", help="Output directory if not comparing checkpoints (subdir of model directory)")
    parser.add_argument("--max_permutations", type=int, default=5, help="Maximum amount of permutations that can be made per caption")
    return parser.parse_args()
43
+
44
+
45
def setup_environment(seed):
    """Seed Python, NumPy and PyTorch random generators and pick a device.

    Returns:
        str: "cuda" when torch reports CUDA as available, otherwise "cpu".
    """
    cuda_available = torch.cuda.is_available()
    for seed_function in (random.seed, np.random.seed, torch.manual_seed):
        seed_function(seed)
    if cuda_available:
        torch.cuda.manual_seed_all(seed)
    return "cuda" if cuda_available else "cpu"
53
+
54
def load_captions_from_json(json_path):
    """Read a JSON file and return the "caption" value of every entry that has one.

    The file is expected to hold a list of dicts; entries without a
    "caption" key are skipped.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)

    captions = []
    for entry in entries:
        if "caption" in entry:
            captions.append(entry["caption"])
    return captions
60
+
61
def _sample_phrase_permutations(phrases, max_permutations):
    """Return at most max_permutations distinct orderings of phrases as tuples."""
    import math  # local import: only needed for the permutation count

    # Few enough orderings: return every one of them
    if math.factorial(len(phrases)) <= max_permutations:
        return list(itertools.permutations(phrases))

    # Too many to enumerate (n! grows quickly): draw distinct random index
    # orderings instead of materialising the full permutation list.
    indices = list(range(len(phrases)))
    seen_orders = set()
    sampled = []
    while len(sampled) < max_permutations:
        random.shuffle(indices)
        order = tuple(indices)
        if order not in seen_orders:
            seen_orders.add(order)
            sampled.append(tuple(phrases[i] for i in order))
    return sampled


def creation_of_parameters(caption, max_permutations):
    """Build the pipeline, tile metadata and dataloader for permuted captions.

    Args:
        caption (str | list[str]): One caption, or a list of captions, whose
            phrases are separated by periods.
        max_permutations (int): Maximum number of phrase orderings generated
            per caption.

    Returns:
        tuple: (pipe, device, id_to_char, char_to_id, tile_descriptors,
        num_tiles, dataloader, perm_captions, caption_data).

    Raises:
        ValueError: If the selected game is unknown.
    """
    args = parse_args()
    device = setup_environment(args.seed)

    # os.path.join avoids the invalid escape sequences of the former
    # backslash literals and yields a valid path on every platform.
    if args.game == "Mario":
        num_tiles = common_settings.MARIO_TILE_COUNT
        default_tileset = os.path.join("..", "TheVGLC", "Super Mario Bros", "smb.json")
    elif args.game == "LR":
        num_tiles = common_settings.LR_TILE_COUNT
        default_tileset = os.path.join("..", "TheVGLC", "Lode Runner", "Loderunner.json")
    else:
        raise ValueError(f"Unknown game: {args.game}")
    # Honour --tileset when given; fall back to the per-game default otherwise
    tileset = args.tileset or default_tileset

    # Load pipeline
    pipe = get_pipeline(args.model_path).to(device)

    # Load tile metadata
    tile_chars, id_to_char, char_to_id, tile_descriptors = extract_tileset(tileset)

    # Treat a single caption like a one-element list; other types yield nothing
    if isinstance(caption, str):
        captions_to_permute = [caption]
    elif isinstance(caption, list):
        captions_to_permute = caption
    else:
        captions_to_permute = []

    perm_captions = []
    for cap in captions_to_permute:
        phrases = [p.strip() for p in cap.split('.') if p.strip()]
        for perm in _sample_phrase_permutations(phrases, max_permutations):
            perm_captions.append('.'.join(perm) + '.')

    # Create a list of dicts as expected by LevelDataset
    caption_data = [{"scene": None, "caption": cap} for cap in perm_captions]

    # Initialize dataset with the tile count of the selected game
    dataset = LevelDataset(
        data_as_list=caption_data,
        shuffle=False,
        mode="text",
        augment=False,
        num_tiles=num_tiles,
        negative_captions=False,
        block_embeddings=None
    )

    # Create dataloader; batch_size must stay positive even with no captions
    dataloader = DataLoader(
        dataset,
        batch_size=max(1, min(16, len(perm_captions))),
        shuffle=False,
        num_workers=4,
        drop_last=False,
        persistent_workers=True
    )

    return pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles, dataloader, perm_captions, caption_data
131
+
132
def statistics_of_captions(captions, dataloader, compare_all_scores, pipe=None, device=None, id_to_char=None, char_to_id=None, tile_descriptors=None, num_tiles=None):
    """
    Print and return summary statistics of caption scores.

    Only captions and compare_all_scores are read; the remaining parameters
    are accepted but unused here.

    Returns:
        tuple: (compare_all_scores, average, standard deviation, minimum,
        maximum, median), or None when captions is empty.
    """
    args = parse_args()
    if not captions:
        print("No captions found in the provided JSON file.")
        return
    print(f"\nLoaded {len(captions)} captions from {args.json}")

    print("\n-----Scores for each caption permutation-----")
    for position, score in enumerate(compare_all_scores, start=1):
        print(f"Scores for caption {position}:", score)

    avg_score = np.mean(compare_all_scores)
    std_dev_score = np.std(compare_all_scores)
    min_score = np.min(compare_all_scores)
    max_score = np.max(compare_all_scores)
    median_score = np.median(compare_all_scores)

    print("\n-----Statistics of captions-----")
    print(f"Average score: {avg_score:.4f}")
    print(f"Standard deviation: {std_dev_score:.4f}")
    print(f"Minimum score: {min_score:.4f}")
    print(f"Maximum score: {max_score:.4f}")
    print(f"Median score: {median_score:.4f}")

    return compare_all_scores, avg_score, std_dev_score, min_score, max_score, median_score
162
+
163
def main():
    """Evaluate how much caption adherence depends on the order of caption phrases.

    Captions come from --caption (comma-separated captions, each with
    period-separated phrases) or, when that is missing or empty, from the
    --json dataset. For every caption the phrase permutations are scored,
    per-caption statistics are collected, and the results are optionally
    written to "evaluation_caption_order_results.jsonl" in --output_dir.
    """
    args = parse_args()
    if args.caption is None or args.caption == "":
        caption = load_captions_from_json(args.json)
    else:
        caption = args.caption
        #caption = ("many pipes. many coins. , many enemies. many blocks. , many platforms. many question blocks.").split(',')

    # A caption given on the command line is a single str. Wrap it so the
    # comprehension iterates over captions, not over individual characters.
    caption_groups = [caption] if isinstance(caption, str) else caption
    all_captions = [item.strip() for s in caption_groups for item in s.split(",")]
    if not all_captions:
        print("No captions to evaluate.")
        return

    all_scores = []
    all_avg_scores = []
    all_std_dev_scores = []
    all_min_scores = []
    all_max_scores = []
    all_median_scores = []

    # open() does not create missing directories
    os.makedirs(args.output_dir, exist_ok=True)
    output_jsonl_path = os.path.join(args.output_dir, "evaluation_caption_order_results.jsonl")
    with open(output_jsonl_path, "w") as f:
        for count, one_caption in enumerate(all_captions):
            # NOTE(review): this rebuilds the pipeline for every caption — confirm that is intended.
            pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles, dataloader, perm_caption, caption_data = creation_of_parameters(one_caption, args.max_permutations)
            if not pipe:
                print("Failed to create pipeline.")
                return

            avg_score, all_samples, all_prompts, compare_all_scores = calculate_caption_score_and_samples(device, pipe, dataloader, args.inference_steps, args.guidance_scale, args.seed, id_to_char, char_to_id, tile_descriptors, args.describe_absence, output=True, height=common_settings.MARIO_HEIGHT, width=common_settings.MARIO_WIDTH)
            scores, avg_score, std_dev_score, min_score, max_score, median_score = statistics_of_captions(perm_caption, dataloader, compare_all_scores, pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles)

            if args.save_as_json:
                result_entry = {
                    "Caption": one_caption,
                    "Average score for all permutations": avg_score,
                    "Standard deviation": std_dev_score,
                    "Minimum score": min_score,
                    "Maximum score": max_score,
                    "Median score": median_score
                }
                f.write(json.dumps(result_entry) + "\n")

            all_avg_scores.append(avg_score)
            # Store the scores themselves, not (index, score) pairs
            all_scores.extend(scores)
            all_std_dev_scores.append(std_dev_score)
            all_min_scores.append(min_score)
            all_max_scores.append(max_score)
            all_median_scores.append(median_score)
            if (count % 10) == 0:
                f.flush()  # Push buffered results to the OS
                os.fsync(f.fileno())  # Ensure file is flushed to disk

        # Report the mean over all captions, not the last caption's average
        print(f"\nAverage score across all captions: {np.mean(all_avg_scores):.4f}")

    visualizations_dir = os.path.join(os.path.dirname(__file__), "visualizations")
    # Only visualize when a non-empty caption was given on the command line
    if args.caption:
        caption_folder = args.caption.replace(" ", "_").replace(".", "_")
        output_directory = os.path.join(visualizations_dir, caption_folder)

        # all_samples / all_prompts belong to the last evaluated caption
        visualize_samples(
            all_samples,
            output_dir=output_directory,
            prompts=all_prompts[0] if all_prompts else "No prompts available"
        )
        print(f"\nVisualizations saved to: {output_directory}")

    print("\nAll samples shape:", all_samples.shape)
    print("\nAll prompts:", all_prompts)

    all_avg_score = np.mean(all_avg_scores)
    all_std_dev_score = np.std(all_std_dev_scores)
    all_min_score = np.min(all_min_scores)
    all_max_score = np.max(all_max_scores)
    all_median_score = np.median(all_median_scores)

    if args.save_as_json:
        # NOTE(review): reopening the same path in "w" mode replaces the
        # per-caption lines written during the loop — confirm this is intended.
        with open(output_jsonl_path, "w") as f:
            if isinstance(caption, list) or (args.caption is None or args.caption == ""):
                # Multiple captions (permuted); all_avg_scores is index-aligned with all_captions
                for i, score in enumerate(all_avg_scores):
                    result_entry = {
                        "Caption": all_captions[i],
                        "Average score for all permutations": score,
                    }
                    f.write(json.dumps(result_entry) + "\n")
            else:
                # Single caption
                result_entry = {
                    "caption": caption,
                    "avg_score": avg_score,
                    "samples": all_samples.tolist(),
                    "prompts": all_prompts
                }
                f.write(json.dumps(result_entry) + "\n")

            results = {
                "Scores of all captions": {
                    "Scores": all_scores,
                    "Number of captions": len(all_scores),
                    "Average of all permutations": all_avg_score,
                    "Standard deviation of all permutations": all_std_dev_score,
                    "Min score of all permutations": all_min_score,
                    "Max score of all permutations": all_max_score,
                    "Median score of all permutations": all_median_score
                },
            }
            json.dump(results, f, indent=4)

        print(f"Results saved to {output_jsonl_path}")


if __name__ == "__main__":
    main()
captions/util.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ import os
4
+ from collections import Counter
5
+
6
+ # This file contains utility functions for analyzing and describing levels in both Lode Runner and Super Mario Bros.
7
+
8
# Caption-generation switches. They could be exposed on the command line,
# but for now they are hardcoded.
#   coarse_counts: describe counts with words ("a few") instead of numbers
#   pluralize:     append "s" / use the plural name when the count calls for it
coarse_locations = coarse_counts = pluralize = True
give_staircase_lengths = False
13
+
14
def describe_size(count):
    """Return "small" for counts up to and including 4, otherwise "big"."""
    return "small" if count <= 4 else "big"
17
+
18
def describe_quantity(count):
    """Translate a count into a coarse quantity word.

    0 -> "no", 1 -> "one", 2 -> "two", below 5 -> "a few",
    below 10 -> "several", anything else -> "many".
    """
    # Exact values are checked first, then the upper bounds in rising order
    for exact_value, word in ((0, "no"), (1, "one"), (2, "two")):
        if count == exact_value:
            return word
    for upper_bound, word in ((5, "a few"), (10, "several")):
        if count < upper_bound:
            return word
    return "many"
25
+
26
def get_tile_descriptors(tileset):
    """Creates a mapping from tile character to its set of descriptors."""
    result = {}
    for char, attrs in tileset["tiles"].items():
        result[char] = set(attrs)
    # Fake tiles. Code elsewhere expects every tile to be passable or solid,
    # so both are marked passable.
    result.update({"!": {"passable"}, "*": {"passable"}})
    return result
33
+
34
def _count_tile_runs(tagged_row, run_tags, break_tags):
    """Count maximal runs of consecutive tiles that carry any of run_tags.

    tagged_row holds (tile, char, descriptors) triples. A tile carrying a
    break tag ends the current run; a tile carrying neither kind is an error.
    """
    runs = 0
    in_run = False
    for tile, char, descriptors in tagged_row:
        if any(tag in descriptors for tag in run_tags):
            if not in_run:
                runs += 1
            in_run = True
        elif any(tag in descriptors for tag in break_tags):
            in_run = False
        else:
            # Put the diagnostics in the exception so they reach the caller
            raise ValueError(
                "Every tile should be passable, solid, or enemy "
                f"(tile={tile!r}, char={char!r}, descriptors={descriptors!r})"
            )
    return runs


def analyze_floor(scene, id_to_char, tile_descriptors, describe_absence):
    """Analyzes the last row of the scene and generates a floor description.

    Args:
        scene (list): 2D list of tile IDs; the last row is the floor row.
        id_to_char (dict): Mapping from tile IDs to characters.
        tile_descriptors (dict): Mapping from characters to descriptors.
        describe_absence (bool): Whether an all-passable floor is reported
            as "no floor" instead of an empty string.

    Returns:
        str: "full floor", "no floor", "", "floor with ... gap(s)", or
        "giant gap with ... chunk(s) of floor".

    Raises:
        ValueError: If a floor tile is neither passable, solid, nor enemy
            while gaps or chunks are being counted.
    """
    WIDTH = len(scene[0])
    last_row = scene[-1]  # The FLOOR row of the scene
    # Resolve each tile once: (tile id, character, descriptors)
    tagged_row = [
        (tile, id_to_char[tile], tile_descriptors.get(id_to_char[tile], []))
        for tile in last_row
    ]
    solid_count = sum(
        1 for _, _, descriptors in tagged_row
        if "solid" in descriptors or "diggable" in descriptors
    )
    passable_count = sum(1 for _, _, descriptors in tagged_row if "passable" in descriptors)

    if solid_count == WIDTH:
        return "full floor"
    if passable_count == WIDTH:
        return "no floor" if describe_absence else ""

    if solid_count > passable_count:
        # Mostly floor: count contiguous groups of openings. Enemies count
        # as a gap since they immediately fall into it.
        gaps = _count_tile_runs(tagged_row, ("passable", "enemy"), ("solid",))
        return f"floor with {describe_quantity(gaps) if coarse_counts else gaps} gap" + ("s" if pluralize and gaps != 1 else "")

    # Mostly open: count contiguous groups of solid tiles
    chunks = _count_tile_runs(tagged_row, ("solid",), ("passable", "enemy"))
    return f"giant gap with {describe_quantity(chunks) if coarse_counts else chunks} chunk"+("s" if pluralize and chunks != 1 else "")+" of floor"
94
+
95
def count_in_scene(scene, tiles, exclude=set()):
    """Count the cells whose tile is in tiles, skipping (row, col) positions in exclude."""
    return sum(
        1
        for r, row in enumerate(scene)
        for c, t in enumerate(row)
        if (r, c) not in exclude and t in tiles
    )
106
+
107
def count_caption_phrase(scene, tiles, name, names, offset = 0, describe_absence=False, exclude=set()):
    """Build the caption phrase for how often the given tiles occur.

    offset is added to the counted occurrences before the phrase is built.
    Returns " <quantity> <name or names>." for a positive count; otherwise
    " no <names>." when describe_absence is set, else an empty string.
    """
    count = offset + count_in_scene(scene, tiles, exclude)
    if count <= 0:
        return f" no {names}." if describe_absence else ""

    amount = describe_quantity(count) if coarse_counts else count
    noun = names if pluralize and count > 1 else name
    return f" {amount} " + noun + "."
117
+
118
def in_column(scene, x, tile):
    """Return True when any row of scene holds tile at column x."""
    return any(row[x] == tile for row in scene)
124
+
125
def analyze_ceiling(scene, id_to_char, tile_descriptors, describe_absence, ceiling_row = 1):
    """
    Inspects one row (0-based index ceiling_row) for a ceiling.

    Returns " full ceiling." when every tile in the row is solid, and
    " ceiling with ... gap(s)." when more than half of them are. Otherwise
    returns " no ceiling." if describe_absence is set, else an empty string.
    """
    width = len(scene[0])
    row_descriptors = [tile_descriptors.get(id_to_char[tile], []) for tile in scene[ceiling_row]]
    solid_count = sum(1 for descriptors in row_descriptors if "solid" in descriptors)

    if solid_count == width:
        return " full ceiling."

    if solid_count <= width // 2:
        # Not enough solid tiles for a ceiling
        return " no ceiling." if describe_absence else ""

    # Count contiguous openings. Enemies open a gap too; they are marked
    # "moving" rather than "passable".
    gaps = 0
    previous_open = False
    for descriptors in row_descriptors:
        is_open = "passable" in descriptors or "moving" in descriptors
        if is_open and not previous_open:
            gaps += 1
        previous_open = is_open

    return f" ceiling with {describe_quantity(gaps) if coarse_counts else gaps} gap" + ("s" if pluralize and gaps != 1 else "") + "."
162
+
163
def extract_tileset(tileset_path):
    """Load a tileset JSON file and derive the tile lookup tables.

    Returns:
        tuple: (tile_chars, id_to_char, char_to_id, tile_descriptors), where
        tile_chars is the sorted list of tile characters and the IDs are the
        positions in that list.
    """
    with open(tileset_path, "r") as handle:
        tileset = json.load(handle)

    # Sorted so that IDs do not depend on the key order in the file.
    # The bogus '!' and '*' tiles are deliberately not given IDs here.
    tile_chars = sorted(tileset['tiles'].keys())
    id_to_char = dict(enumerate(tile_chars))
    char_to_id = {char: idx for idx, char in id_to_char.items()}
    tile_descriptors = get_tile_descriptors(tileset)

    return tile_chars, id_to_char, char_to_id, tile_descriptors
182
+
183
def flood_fill(scene, visited, start_row, start_col, id_to_char, tile_descriptors, excluded, pipes=False, target_descriptor=None):
    """Collect the connected cells reachable from a start cell.

    A cell belongs to the structure when its descriptors contain
    target_descriptor, if one is given. Otherwise the tile must be solid
    and be a pipe exactly when pipes is set. Cells in visited or excluded
    are skipped. Accepted cells are added to visited (mutated in place).

    Returns:
        list: (row, col) positions of the structure in the order visited.
    """
    neighbor_offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    pending = [(start_row, start_col)]
    structure = []

    while pending:
        cell = pending.pop()
        if cell in visited or cell in excluded:
            continue

        row, col = cell
        tile = scene[row][col]
        char = id_to_char[tile]
        descriptors = tile_descriptors.get(char, [])

        if target_descriptor is not None:
            accepted = target_descriptor in descriptors
        else:
            # Default rule: solid, and "pipe" present only when pipes are wanted
            has_pipe = "pipe" in descriptors
            accepted = "solid" in descriptors and (has_pipe if pipes else not has_pipe)
        if not accepted:
            continue

        visited.add(cell)
        structure.append(cell)

        # Special case for adjacent pipes: never step outward across a pipe edge
        on_right_edge = char in ('>', ']')
        on_left_edge = char in ('<', '[')
        for d_row, d_col in neighbor_offsets:
            if on_right_edge and d_col == 1:
                continue  # Don't go right if on the right edge of a pipe
            if on_left_edge and d_col == -1:
                continue  # Don't go left if on the left edge of a pipe

            n_row, n_col = row + d_row, col + d_col
            if 0 <= n_row < len(scene) and 0 <= n_col < len(scene[0]):
                pending.append((n_row, n_col))

    return structure
mapsheet.png ADDED
models/general_training_helper.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from level_dataset import LevelDataset
3
+ import random
4
+ from util.plotter import Plotter
5
+ from datetime import datetime
6
+ import os
7
+ import threading
8
+ import json
9
+ import torch.nn.functional as F
10
+ import torch
11
+
12
+
13
+
14
+
15
def create_dataloaders(json_path, val_json, tokenizer, data_mode, augment, num_tiles,
                       negative_prompt_training, block_embeddings, batch_size):
    """Build the training DataLoader and, optionally, a validation DataLoader.

    The training dataset is shuffled and uses the requested ``augment``
    setting; the validation dataset (only built when ``val_json`` is not
    None) is never shuffled or augmented. Both loaders use 4 persistent
    workers; only the training loader drops the last incomplete batch.

    Returns:
        (train_dataloader, val_dataloader), where val_dataloader is None
        when no validation file was given.
    """
    # Settings shared by both datasets / both loaders.
    dataset_common = dict(
        tokenizer=tokenizer,
        mode=data_mode,
        num_tiles=num_tiles,
        negative_captions=negative_prompt_training,
        block_embeddings=block_embeddings,
    )
    loader_common = dict(
        batch_size=batch_size,
        num_workers=4,
        persistent_workers=True,
    )

    # Datasets first, then loaders (same construction order as before).
    train_dataset = LevelDataset(json_path=json_path, shuffle=True, augment=augment, **dataset_common)
    val_dataset = (
        LevelDataset(json_path=val_json, shuffle=False, augment=False, **dataset_common)
        if val_json is not None
        else None
    )

    train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_common)
    val_dataloader = (
        DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_common)
        if val_dataset is not None
        else None
    )

    return train_dataloader, val_dataloader
63
+
64
+
65
def get_random_training_samples(train_dataloader, negative_prompt_training, output_dir = None):
    """Pick four random training items and report their captions.

    Item element 1 is used as the caption and, when
    ``negative_prompt_training`` is true, element 2 as the negative caption.
    The captions are printed and, if ``output_dir`` is given, also written
    to ``sample_captions.txt`` inside it (the directory is created if needed).

    Returns:
        (sample_captions, sample_negative_captions). The second value is a
        list when negative prompts are enabled and the empty string otherwise.
    """
    train_dataset = train_dataloader.dataset

    # Indices are drawn with replacement, so duplicates are possible.
    sample_indices = []
    for _ in range(4):
        sample_indices.append(random.randint(0, len(train_dataset) - 1))

    sample_captions = [train_dataset[idx][1] for idx in sample_indices]
    print("Sample captions:")
    for text in sample_captions:
        print(text)

    if negative_prompt_training:
        sample_negative_captions = [train_dataset[idx][2] for idx in sample_indices]
        print("Sample negative captions:")
        for text in sample_negative_captions:
            print(f"  NEG: {text}")
    else:
        sample_negative_captions = ""

    if output_dir is None:
        return sample_captions, sample_negative_captions

    # Mirror the console output into a text file.
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, "sample_captions.txt")
    file_lines = ["Sample captions:\n"]
    file_lines.extend(str(text) + "\n" for text in sample_captions)
    if negative_prompt_training:
        file_lines.append("\nSample negative captions:\n")
        file_lines.extend(str(text) + "\n" for text in sample_negative_captions)
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(file_lines)
    print(f"Sample captions written to {out_path}")

    return sample_captions, sample_negative_captions
98
+
99
+
100
def start_plotter(log_file, output_dir, left_key, right_key, left_label, right_label, png_name):
    """Start a background thread that plots training progress from a log file.

    A ``Plotter`` is created with a 5 second update interval and a
    timestamped PNG file name, then run on a daemon thread.

    Returns:
        (plotter, plot_thread) so the caller can stop the plotter later.
    """
    timestamp = datetime.now().strftime(r'%Y%m%d-%H%M%S')
    png_filename = f'{png_name}_{timestamp}.png'

    # Only the bare file name is handed to the Plotter; output_dir is used
    # here solely to build the path shown in the status message.
    plotter = Plotter(
        log_file,
        update_interval=5.0,
        left_key=left_key,
        right_key=right_key,
        left_label=left_label,
        right_label=right_label,
        output_png=png_filename,
    )

    # Daemon thread: it must not keep the process alive after training ends.
    plot_thread = threading.Thread(target=plotter.start_plotting, daemon=True)
    plot_thread.start()

    print(f"{png_name} plotting enabled. Progress will be saved to {os.path.join(output_dir, png_filename)}")
    return plotter, plot_thread
110
+
111
+
112
def kill_plotter(plotter, plot_thread):
    """Ask a running plotter to stop and wait up to 5 seconds for its thread.

    Does nothing when there is no thread or it has already finished. Prints
    a warning if the thread is still alive after the timeout.
    """
    if not plot_thread or not plot_thread.is_alive():
        return

    plotter.stop_plotting()
    plot_thread.join(timeout=5.0)

    if plot_thread.is_alive():
        print("Warning: Plot thread did not terminate properly")
118
+
119
+
120
def load_config_from_json(config_path):
    """Load hyperparameters from a JSON config file.

    The file is read as UTF-8 (the JSON interchange encoding) rather than the
    platform's locale encoding, so non-ASCII values load identically on
    every OS. The loaded key/value pairs are printed for verification.

    Args:
        config_path: path to the JSON file; its top level must be an object.

    Returns:
        The parsed configuration as a dict.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Keep the try body limited to the operations that can actually fail.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"Error loading config file: {e}")
        # Bare raise preserves the original traceback.
        raise

    print(f"Configuration loaded from {config_path}")

    # Print the loaded config for verification
    print("Loaded hyperparameters:")
    for key, value in config.items():
        print(f"  {key}: {value}")

    return config
136
+
137
+
138
def update_args_from_config(args, config):
    """Overwrite attributes of an argparse namespace with config values.

    Only keys that already exist as attributes on ``args`` are applied;
    unknown config keys are silently ignored. ``args`` is modified in place
    and also returned.
    """
    for key in config:
        if not hasattr(args, key):
            continue
        setattr(args, key, config[key])
    return args
145
+
146
+
147
def get_scene_from_embeddings(image, block_embeddings):
    """Convert an embedding-space sample into per-tile probability maps.

    Code copied over from level_dataset, should give limited support for
    block embeddings. Every spatial position of ``image`` is compared with
    every row of ``block_embeddings`` using cosine similarity, and a softmax
    over those similarities gives a distribution over tile types.

    Args:
        image: tensor of shape (batch_size, embedding_dim, height, width).
        block_embeddings: tensor of shape (num_tiles, embedding_dim), one
            row per tile type.

    Returns:
        CPU tensor of shape (batch_size, num_tiles, height, width) whose
        values sum to 1 along dim 1.
    """
    # Reshape sample to [batch_size * height * width, embedding_dim]
    batch_size, embedding_dim, height, width = image.shape
    flat_samples = image.permute(0, 2, 3, 1).reshape(-1, embedding_dim)

    # Normalize vectors for cosine similarity. The samples are moved to the
    # CPU, so the embedding table must be too or matmul fails on a device
    # mismatch.
    flat_samples = F.normalize(flat_samples, p=2, dim=1).cpu()
    block_embeddings = F.normalize(block_embeddings, p=2, dim=1).cpu()

    # Calculate cosine similarity between each position and all tile embeddings
    similarities = torch.matmul(flat_samples, block_embeddings.t())

    # Softmax over tile types: a probability per tile, not a hard index.
    tile_probs = torch.softmax(similarities, dim=1)

    # Reshape back to [batch_size, height, width, num_tiles]; the tile count
    # comes from the embedding table instead of being fixed at 13.
    num_tiles = block_embeddings.shape[0]
    tile_probs = tile_probs.reshape(batch_size, height, width, num_tiles)
    tile_probs = tile_probs.permute(0, 3, 1, 2)

    return tile_probs.detach().cpu()
171
+
172
+
models/latent_diffusion_pipeline.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DDPMPipeline
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Union, List, Tuple
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+ from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
7
+ import util.common_settings as common_settings
8
+ import os
9
+ import json
10
+ from models.general_training_helper import get_scene_from_embeddings
11
+
12
class UnconditionalDDPMPipeline(DDPMPipeline):
    """DDPM pipeline without text conditioning.

    Extends ``DDPMPipeline`` with an optional block-embedding table that is
    saved/loaded next to the model, and with optional per-sprite scaling of
    the final sample.
    """

    def __init__(self, unet, scheduler, block_embeddings=None):
        """Create the pipeline; ``block_embeddings`` may be None."""
        super().__init__(unet, scheduler)

        self.block_embeddings = block_embeddings


    def save_pretrained(self, save_directory):
        """Save the pipeline and, if present, ``block_embeddings.pt``."""
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory)
        # Save block_embeddings tensor if it exists
        if self.block_embeddings is not None:
            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Load the pipeline and attach ``block_embeddings`` when the file exists.

        The embeddings file is looked up with ``os.path`` relative to
        ``pretrained_model_path``; when it is absent, ``block_embeddings``
        is set to None.
        """
        pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
        # Load block_embeddings tensor if it exists
        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(block_embeds_path):
            # NOTE(review): torch.load unpickles the file; only load model
            # directories from trusted sources.
            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
        else:
            pipeline.block_embeddings = None
        return pipeline



    def give_sprite_scaling_factors(self, sprite_scaling_factors):
        """
        Set the sprite scaling factors for the pipeline.
        This is used to apply per-sprite temperature scaling during inference.
        """
        self.sprite_scaling_factors = sprite_scaling_factors

    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        height: int = common_settings.MARIO_HEIGHT, width: int = common_settings.MARIO_WIDTH,
        latents: Optional[torch.FloatTensor] = None,
        show_progress_bar=True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Run the denoising loop and return the generated sample.

        Args:
            batch_size: number of samples when starting from noise.
            generator: passed to ``torch.randn`` and to every scheduler step.
            num_inference_steps: number of scheduler timesteps.
            output_type: accepted but not read anywhere in this method.
            return_dict: if False, return a one-element tuple instead of an
                ``ImagePipelineOutput``.
            height, width: spatial size of the initial noise.
            latents: optional starting sample; when given it replaces the
                random noise (batch_size/height/width are then unused).
            show_progress_bar: wrap the timesteps in the pipeline progress bar.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``. ``images``
            is a detached CPU tensor: the output of
            ``get_scene_from_embeddings`` when block embeddings are set,
            otherwise a softmax over the channel dimension.
        """

        self.unet.eval()
        with torch.no_grad():

            if latents is not None:
                image = latents.to(self.device)
            else:
                image_shape = (
                    batch_size,
                    self.unet.config.in_channels,
                    height,
                    width
                )

                # NOTE(review): the annotation allows a list of generators,
                # but torch.randn accepts only a single one - confirm callers.
                image = torch.randn(image_shape, generator=generator, device=self.device)

            self.scheduler.set_timesteps(num_inference_steps)

            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
            for t in iterator:
                #print(image.shape)
                model_output = self.unet(image, t).sample
                image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            # Apply per-sprite temperature scaling if enabled
            # NOTE(review): assumed to run once after the denoising loop, not
            # on every step - confirm against the upstream file.
            if hasattr(self,"sprite_scaling_factors") and self.sprite_scaling_factors is not None:
                image = image / self.sprite_scaling_factors.view(1, -1, 1, 1)


            if self.block_embeddings is not None:
                image = get_scene_from_embeddings(image, self.block_embeddings)
            else:
                image = F.softmax(image, dim=1)
            image = image.detach().cpu()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def print_unet_architecture(self):
        """Prints the architecture of the UNet model."""
        print(self.unet)
models/pipeline_loader.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.text_diffusion_pipeline import TextConditionalDDPMPipeline
2
+ from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline
3
+ import os
4
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
5
+
6
+
7
def get_pipeline(model_path):
    """Load the diffusion pipeline found at ``model_path``.

    Args:
        model_path: a local model directory, or otherwise a string that is
            treated as a Hugging Face Hub model ID.

    Returns:
        A ``TextConditionalDDPMPipeline`` when the model has a text encoder,
        otherwise an ``UnconditionalDDPMPipeline`` (for Hub IDs, whatever
        ``DiffusionPipeline.from_pretrained`` builds from the matching
        custom pipeline).

    Raises:
        ValueError: if ``model_path`` is a local directory without a
            ``unet`` sub-directory. Previously this case reached the final
            ``return`` with no pipeline assigned and failed with an
            UnboundLocalError.
    """
    # If model_path is a local directory, use the original logic
    if os.path.isdir(model_path):
        # Diffusion models always ship a unet folder.
        if not os.path.exists(os.path.join(model_path, "unet")):
            raise ValueError(
                f"'{model_path}' is a directory but contains no 'unet' folder; "
                "cannot determine which pipeline to load"
            )
        if os.path.exists(os.path.join(model_path, "text_encoder")):
            # If it has a text encoder and a unet, it's text conditional diffusion
            return TextConditionalDDPMPipeline.from_pretrained(model_path)
        # If it has no text encoder, use the unconditional diffusion model
        return UnconditionalDDPMPipeline.from_pretrained(model_path)

    # Assume it's a Hugging Face Hub model ID
    # Try to load config to determine if it's text-conditional
    try:
        loaded = DiffusionPipeline.load_config(model_path)
        # load_config returns the config dict alone unless extra return
        # values are requested; unpacking it as a pair raised an error that
        # was swallowed below. Accept both shapes.
        config = loaded[0] if isinstance(loaded, tuple) else loaded
    except Exception:
        # Best effort: if the config cannot be fetched, fall back to the
        # unconditional pipeline below.
        config = {}

    # Components are normally top-level keys of the pipeline config; the
    # nested "components" lookup is kept for backward compatibility.
    components = config.get("components", {})
    is_text_conditional = (
        "text_encoder" in config
        or "text_encoder" in components
        or "text_encoder" in str(components)
    )

    if is_text_conditional:
        # Use the local pipeline file for custom_pipeline
        pipe = DiffusionPipeline.from_pretrained(
            model_path,
            custom_pipeline="models.text_diffusion_pipeline.TextConditionalDDPMPipeline",
            trust_remote_code=True,
        )
    else:
        # Fallback: try unconditional
        pipe = DiffusionPipeline.from_pretrained(
            model_path,
            custom_pipeline="models.latent_diffusion_pipeline.UnconditionalDDPMPipeline",
            trust_remote_code=True,
        )

    return pipe
models/sentence_transformers_helper.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ #Mean Pooling - Take average of all tokens
6
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked tokens.

    ``model_output.last_hidden_state`` is weighted by ``attention_mask``
    (broadcast over the embedding dimension), summed over dim 1 and divided
    by the mask sum, which is clamped to at least 1e-9 to avoid division by
    zero for fully masked rows.
    """
    hidden_states = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_sum = torch.sum(hidden_states * mask, 1)
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return masked_sum / token_counts
10
+
11
+
12
+ #Encode text
13
def encode(texts, tokenizer, model, device='cpu'):
    """Return one L2-normalized, mean-pooled embedding per input text.

    The texts are tokenized with padding and truncation, run through
    ``model`` without gradient tracking, mean-pooled over tokens and
    normalized along dim 1. The result is placed on ``device``.
    """
    # Tokenize sentences
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # The return value is deliberately not used, as before; this relies on
    # .to() moving the tokenizer output in place.
    batch.to(device)

    # Compute token embeddings
    with torch.no_grad():
        outputs = model(**batch, return_dict=True)

    # Pool over tokens, then normalize each embedding to unit length.
    pooled = mean_pooling(outputs, batch['attention_mask'])
    normalized = F.normalize(pooled, p=2, dim=1)

    return normalized.to(device)
31
+
32
+ # Get embeddings for a batch of captions and optional negative captions
33
def get_embeddings(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu'):
    """Build the stacked embeddings for guided generation.

    The result concatenates along dim 0, in this order: negative-caption
    embeddings (if given), embeddings of ``batch_size`` empty strings, and
    caption embeddings (if given). A singleton dimension is then inserted
    at position 1.
    """
    # Empty-string embeddings are always present.
    pieces = [encode([""] * batch_size, tokenizer, model, device)]

    if captions is not None:
        pieces.append(encode(captions, tokenizer, model, device))

    if neg_captions is not None:
        # Negative embeddings go in front of everything else.
        pieces.insert(0, encode(neg_captions, tokenizer, model, device))

    stacked = torch.cat(pieces, dim=0)
    return stacked.unsqueeze(1)
48
+
49
+
50
+
51
+
52
def get_embeddings_split(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu', max_length=20):
    """Build per-sentence embeddings for guided generation.

    Every caption is split on periods and each sentence is embedded
    separately, so each caption becomes a sequence of sentence embeddings
    padded with empty-string embeddings to a common length. The result
    concatenates along dim 0, in this order: negative captions (if given),
    ``batch_size`` empty captions, and captions (if given).

    The common length is at least 1 and covers both the period count and
    the number of non-empty sentences of every caption. Counting periods
    alone under-counts captions whose last sentence has no trailing period,
    which left sequences ragged and made stacking fail, and gave a length
    of 0 when both lists were given without any period.

    Raises:
        ValueError: if the required sequence length exceeds ``max_length``.
    """
    def _required_length(caption_list):
        # Sequence length needed so that no caption in the list is truncated.
        if not caption_list:
            return 1
        return max(
            max(s.count("."), sum(1 for part in s.split(".") if part.strip()))
            for s in caption_list
        )

    padding_length = max(_required_length(captions), _required_length(neg_captions), 1)
    if padding_length > max_length:
        raise ValueError(f"Token sequence length {padding_length} exceeds specified length {max_length}.")

    empty_split = split_sentences([""] * batch_size, padding_length)
    embeddings = get_embeddings_from_split(empty_split, tokenizer, model, device)

    if captions is not None:
        captions_split = split_sentences(captions, padding_length)
        caption_embeddings = get_embeddings_from_split(captions_split, tokenizer, model, device)
        embeddings = torch.cat((embeddings, caption_embeddings), dim=0)

    if neg_captions is not None:
        neg_split = split_sentences(neg_captions, padding_length)
        neg_embeddings = get_embeddings_from_split(neg_split, tokenizer, model, device)
        embeddings = torch.cat((neg_embeddings, embeddings), dim=0)

    # We don't need to unsqueeze this: each caption already contributes a
    # (padding_length, encoding_size) sequence.
    return embeddings.to(device)
77
+
78
+
79
+ #This method takes a caption batch in list form, and outputs a 2d list where every caption has been split by period
80
def split_sentences(caption_array, padding_length=20):
    """Split each caption on periods and pad the pieces to a fixed length.

    Pieces are stripped of surrounding whitespace and empty pieces are
    dropped. Each caption's list is then padded with empty strings up to
    ``padding_length``; lists that are already longer are left as they are.

    Returns:
        A list with one list of sentence strings per caption.
    """
    padded_captions = []

    for caption in caption_array:
        sentences = []
        for piece in caption.split("."):
            piece = piece.strip()
            if piece:
                sentences.append(piece)
        # This is the token padding, we just use an empty string
        sentences.extend([""] * (padding_length - len(sentences)))
        padded_captions.append(sentences)

    return padded_captions
91
+
92
+
93
+ #Expects all split vectors to be the same length
94
def get_embeddings_from_split(caption_batch, tokenizer, model, device='cpu'):
    """Embed a batch of pre-split captions.

    Each caption (a list of sentence strings) is encoded as if it were a
    batch, giving one tensor per caption; the tensors are stacked along a
    new leading dimension. All captions must have the same number of
    sentences.
    """
    # No reshaping per caption, so the result needs no unsqueeze later.
    per_caption = [
        encode(sentence_list, tokenizer, model, device)
        for sentence_list in caption_batch
    ]
    return torch.stack(per_caption, dim=0)
105
+
106
+
107
+
108
if __name__ == "__main__":
    # Manual smoke test: split two sample captions and embed them with a
    # pretrained sentence-transformers model.
    cap = split_sentences(["Hello. My name is George. How. Are you doing. Today?", "I am doing. Just fine. Thanks."])
    model_url = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
    # Fall back to the CPU so the smoke test also runs without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained(model_url)
    model = AutoModel.from_pretrained(model_url, trust_remote_code=True).to(device)
    get_embeddings_from_split(cap, tokenizer, model, device)
models/text_diffusion_pipeline.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import NamedTuple, Optional
4
+ import os
5
+ from diffusers import DDPMPipeline, UNet2DConditionModel, DDPMScheduler
6
+ import json
7
+ # Running the main at the end of this requires messing with this import
8
+ from models.text_model import TransformerModel
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import AutoTokenizer, AutoModel
12
+ import util.common_settings as common_settings
13
+ import models.sentence_transformers_helper as st_helper
14
+ import models.text_model as text_model
15
+ from models.general_training_helper import get_scene_from_embeddings
16
+
17
class PipelineOutput(NamedTuple):
    """Return value of ``TextConditionalDDPMPipeline.__call__``."""
    # The generated sample tensor produced by the pipeline.
    images: torch.Tensor
19
+
20
+
21
+
22
+ # Create a custom pipeline for text-conditional generation
23
+ class TextConditionalDDPMPipeline(DDPMPipeline):
24
+ def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
25
+ super().__init__(unet=unet, scheduler=scheduler)
26
+ self.text_encoder = text_encoder
27
+ self.tokenizer = tokenizer
28
+ self.supports_negative_prompt = hasattr(unet, 'negative_prompt_support') and unet.negative_prompt_support
29
+ self.supports_pretrained_split = supports_pretrained_split
30
+ self.block_embeddings = block_embeddings
31
+
32
+ if self.tokenizer is None and self.text_encoder is not None:
33
+ # Use the tokenizer from the text encoder if not provided
34
+ self.tokenizer = self.text_encoder.tokenizer
35
+
36
+ # Register the text_encoder so that .to(), .cpu(), .cuda(), etc. work correctly
37
+ self.register_modules(
38
+ unet=unet,
39
+ scheduler=scheduler,
40
+ text_encoder=self.text_encoder,
41
+ tokenizer=self.tokenizer,
42
+ )
43
+
44
+ # Override the to() method to ensure text_encoder is moved to the correct device
45
+ def to(self, device=None, dtype=None):
46
+ # Call the parent's to() method first
47
+ pipeline = super().to(device, dtype)
48
+
49
+ # Additionally move the text_encoder to the device
50
+ if self.text_encoder is not None:
51
+ self.text_encoder.to(device)
52
+
53
+ return pipeline
54
+
55
+ def save_pretrained(self, save_directory):
56
+ os.makedirs(save_directory, exist_ok=True)
57
+ super().save_pretrained(save_directory) # saves UNet and scheduler
58
+
59
+ # Save block_embeddings tensor if it exists
60
+ if self.block_embeddings is not None:
61
+ torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))
62
+
63
+ # Save supports_negative_prompt and supports_pretrained_split flags
64
+ with open(os.path.join(save_directory, "pipeline_config.json"), "w") as f:
65
+ json.dump({
66
+ "supports_negative_prompt": self.supports_negative_prompt,
67
+ "supports_pretrained_split": self.supports_pretrained_split,
68
+ "text_encoder_type": type(self.text_encoder).__name__
69
+ }, f)
70
+
71
+
72
+ #Text encoder/tokenizer saving is different depending on if we're using a larger pretrained model
73
+ if isinstance(self.text_encoder, TransformerModel):
74
+ # Save custom text encoder
75
+ if self.text_encoder is not None:
76
+ self.text_encoder.save_pretrained(os.path.join(save_directory, "text_encoder"))
77
+ else:
78
+ #Save pretrained tokenizer by name, so we can load from huggingface instead of saving a giant local model
79
+ text_encoder_info = {
80
+ "text_encoder_name": self.text_encoder.config.name_or_path,
81
+ "tokenizer_name": self.tokenizer.name_or_path,
82
+ }
83
+
84
+ text_encoder_directory = os.path.join(save_directory, "text_encoder")
85
+ os.makedirs(text_encoder_directory, exist_ok=True)
86
+
87
+ with open(os.path.join(text_encoder_directory, "loading_info.json"), "w") as f:
88
+ json.dump(text_encoder_info, f)
89
+
90
+
91
+
92
+ @classmethod
93
+ def from_pretrained(cls, pretrained_model_path, **kwargs):
94
+ #from diffusers.utils import load_config, load_state_dict
95
+ # Load model_index.json
96
+ #model_index = load_config(pretrained_model_path)
97
+
98
+ # Load components manually
99
+ unet_path = os.path.join(pretrained_model_path, "unet")
100
+ unet = UNet2DConditionModel.from_pretrained(unet_path)
101
+
102
+ scheduler_path = os.path.join(pretrained_model_path, "scheduler")
103
+ # Have heard that DDIMScheduler might be faster for inference, though not necessarily better
104
+ scheduler = DDPMScheduler.from_pretrained(scheduler_path)
105
+
106
+ tokenizer = None
107
+ text_encoder_path = os.path.join(pretrained_model_path, "text_encoder")
108
+
109
+ if os.path.exists(text_encoder_path):
110
+ #Test for the new saving system, where we save a simple config file
111
+ if os.path.exists(os.path.join(text_encoder_path, "loading_info.json")):
112
+ with open(os.path.join(text_encoder_path, "loading_info.json"), "r") as f:
113
+ encoder_config = json.load(f)
114
+
115
+ text_encoder = AutoModel.from_pretrained(encoder_config['text_encoder_name'], trust_remote_code=True)
116
+ tokenizer = AutoTokenizer.from_pretrained(encoder_config['tokenizer_name'])
117
+
118
+ #Legacy loading system, loads models directly if the whole thing is saved in the directory
119
+ else:
120
+ try:
121
+ text_encoder = AutoModel.from_pretrained(text_encoder_path, local_files_only=True, trust_remote_code=True)
122
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_path, local_files_only=True)
123
+ except (ValueError, KeyError):
124
+ text_encoder = TransformerModel.from_pretrained(text_encoder_path)
125
+ tokenizer = text_encoder.tokenizer
126
+ else:
127
+ text_encoder = None
128
+
129
+ # Instantiate your pipeline
130
+ pipeline = cls(
131
+ unet=unet,
132
+ scheduler=scheduler,
133
+ text_encoder=text_encoder,
134
+ tokenizer=tokenizer,
135
+ **kwargs,
136
+ )
137
+
138
+ #Loads block embeddings if present
139
+ block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
140
+ if os.path.exists(block_embeds_path):
141
+ pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
142
+ else:
143
+ pipeline.block_embeddings = None
144
+
145
+
146
+ # Load supports_negative_prompt flag if present
147
+ config_path = os.path.join(pretrained_model_path, "pipeline_config.json")
148
+ if os.path.exists(config_path):
149
+ with open(config_path, "r") as f:
150
+ config = json.load(f)
151
+ pipeline.supports_negative_prompt = config.get("supports_negative_prompt", False)
152
+ pipeline.supports_pretrained_split = config.get("supports_pretrained_split", False)
153
+ return pipeline
154
+
155
+ # --- Handle batching for captions ---
156
+ def _prepare_text_batch(self, text: Optional[str | list[str]], batch_size: int, name: str) -> Optional[list[str]]:
157
+ if text is None:
158
+ return None
159
+ if isinstance(text, str):
160
+ return [text] * batch_size
161
+ if isinstance(text, list):
162
+ if len(text) == 1:
163
+ return text * batch_size
164
+ if len(text) != batch_size:
165
+ raise ValueError(f"{name} list length {len(text)} does not match batch_size {batch_size}")
166
+ return text
167
+ raise ValueError(f"{name} must be a string or list of strings")
168
+
169
+ def _prepare_initial_sample(self,
170
+ raw_latent_sample: Optional[torch.Tensor],
171
+ input_scene: Optional[torch.Tensor],
172
+ batch_size: int, height: int, width: int,
173
+ generator: Optional[torch.Generator]) -> torch.Tensor:
174
+ """Prepare the initial sample for diffusion."""
175
+
176
+ sample_shape = (batch_size, self.unet.config.in_channels, height, width)
177
+
178
+ if raw_latent_sample is not None:
179
+ if input_scene is not None:
180
+ raise ValueError("Cannot provide both raw_latent_sample and input_scene")
181
+ sample = raw_latent_sample.to(self.device)
182
+ if sample.shape[1] != sample_shape[1]:
183
+ raise ValueError(f"Wrong number of channels in raw_latent_sample: Expected {self.unet.config.in_channels} but got {sample.shape[1]}")
184
+ if sample.shape[0] == 1 and batch_size > 1:
185
+ sample = sample.repeat(batch_size, 1, 1, 1)
186
+ elif sample.shape[0] != batch_size:
187
+ raise ValueError(f"raw_latent_sample batch size {sample.shape[0]} does not match batch_size {batch_size}")
188
+ elif input_scene is not None:
189
+ # input_scene can be (H, W) or (batch_size, H, W)
190
+ scene_tensor = torch.tensor(input_scene, dtype=torch.long, device=self.device)
191
+ if scene_tensor.dim() == 2:
192
+ # (H, W) -> repeat for batch
193
+ scene_tensor = scene_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
194
+ elif scene_tensor.shape[0] == 1 and batch_size > 1:
195
+ scene_tensor = scene_tensor.repeat(batch_size, 1, 1)
196
+ elif scene_tensor.shape[0] != batch_size:
197
+ raise ValueError(f"input_scene batch size {scene_tensor.shape[0]} does not match batch_size {batch_size}")
198
+ # One-hot encode: (batch, H, W, C)
199
+ one_hot = F.one_hot(scene_tensor, num_classes=self.unet.config.in_channels).float()
200
+ # (batch, H, W, C) -> (batch, C, H, W)
201
+ sample = one_hot.permute(0, 3, 1, 2)
202
+ else:
203
+ # Start from random noise
204
+ sample = torch.randn(sample_shape, generator=generator, device=self.device)
205
+
206
+ return sample
207
+
208
+ def __call__(
209
+ self,
210
+ caption: Optional[str | list[str]] = None,
211
+ negative_prompt: Optional[str | list[str]] = None,
212
+ generator: Optional[torch.Generator] = None,
213
+ num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
214
+ guidance_scale: float = common_settings.GUIDANCE_SCALE,
215
+ height: int = common_settings.MARIO_HEIGHT,
216
+ width: int = common_settings.MARIO_WIDTH,
217
+ raw_latent_sample: Optional[torch.FloatTensor] = None,
218
+ input_scene: Optional[torch.Tensor] = None,
219
+ output_type: str = "tensor",
220
+ batch_size: int = 1,
221
+ show_progress_bar: bool = True,
222
+ ) -> PipelineOutput:
223
+ """Generate a batch of images based on text input using the diffusion model.
224
+
225
+ Args:
226
+ caption: Text description(s) of the desired output. Can be a string or list of strings.
227
+ negative_prompt: Text description(s) of what should not appear in the output. String or list.
228
+ generator: Random number generator for reproducibility.
229
+ num_inference_steps: Number of denoising steps (more = higher quality, slower).
230
+ guidance_scale: How strongly the generation follows the text prompt (higher = stronger).
231
+ height: Height of generated image in tiles.
232
+ width: Width of generated image in tiles.
233
+ raw_latent_sample: Optional starting point for diffusion instead of random noise.
234
+ Must have correct number of channels matching the UNet.
235
+ input_scene: Optional 2D or 3D int tensor where each value corresponds to a tile type.
236
+ Will be converted to one-hot encoding as starting point.
237
+ output_type: Currently only "tensor" is supported.
238
+ batch_size: Number of samples to generate in parallel.
239
+
240
+ Returns:
241
+ PipelineOutput containing the generated image tensor (batch_size, ...).
242
+ """
243
+
244
+ # I would like to simplify the code to this, but the AI suggestion didn't work, and
245
+ # I did not feel good just pasting it all in. Will need to tackle it bit by bit.
246
+
247
+ # if caption is not None and self.text_encoder is None:
248
+ # raise ValueError("Text encoder required for conditional generation")
249
+
250
+ # self.unet.eval()
251
+ # if self.text_encoder is not None:
252
+ # self.text_encoder.to(self.device)
253
+ # self.text_encoder.eval()
254
+ #
255
+ # with torch.no_grad():
256
+ # # Process text inputs
257
+ # captions = self.prepare_text_batch(caption, batch_size, "caption")
258
+ # negatives = self.prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
259
+
260
+ # # Get embeddings
261
+ # text_embeddings = self.prepare_embeddings(captions, negatives, batch_size)
262
+ #
263
+ # # Set up initial latent state
264
+ # sample = self.prepare_initial_sample(raw_latent_sample, input_scene,
265
+ # batch_size, height, width, generator)
266
+
267
+ # # Run diffusion process
268
+ # sample = self.run_diffusion(sample, text_embeddings, num_inference_steps,
269
+ # guidance_scale, generator, show_progress_bar,
270
+ # has_caption=caption is not None,
271
+ # has_negative=negative_prompt is not None)
272
+
273
+ # # Format output
274
+ # if output_type == "tensor":
275
+ # sample = F.softmax(sample, dim=1)
276
+ # else:
277
+ # raise ValueError(f"Unsupported output type: {output_type}")
278
+
279
+ # return PipelineOutput(images=sample)
280
+
281
+ # Validate text encoder if we need it
282
+ if caption is not None and self.text_encoder is None:
283
+ raise ValueError("Text encoder is required for conditional generation")
284
+
285
+ self.unet.eval()
286
+ if self.text_encoder is not None:
287
+ self.text_encoder.to(self.device)
288
+ self.text_encoder.eval()
289
+
290
+ with torch.no_grad():
291
+ captions = self._prepare_text_batch(caption, batch_size, "caption")
292
+ negatives = self._prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
293
+
294
+ # --- Prepare text embeddings ---
295
+ if(isinstance(self.text_encoder, TransformerModel)):
296
+ text_embeddings = text_model.get_embeddings(batch_size=batch_size,
297
+ tokenizer=self.text_encoder.tokenizer,
298
+ text_encoder=self.text_encoder,
299
+ captions=captions,
300
+ neg_captions=negatives,
301
+ device=self.device)
302
+ else: #Case for the pre-trained text encoder
303
+ if(self.supports_pretrained_split): #If we have a split flag incorporated
304
+ text_embeddings = st_helper.get_embeddings_split(batch_size = batch_size,
305
+ tokenizer=self.tokenizer,
306
+ model=self.text_encoder,
307
+ captions=captions,
308
+ neg_captions=negatives,
309
+ device=self.device)
310
+ else:
311
+ text_embeddings = st_helper.get_embeddings(batch_size = batch_size,
312
+ tokenizer=self.tokenizer,
313
+ model=self.text_encoder,
314
+ captions=captions,
315
+ neg_captions=negatives,
316
+ device=self.device)
317
+
318
+
319
+ # --- Set up initial latent state ---
320
+ sample = self._prepare_initial_sample(raw_latent_sample, input_scene,
321
+ batch_size, height, width, generator)
322
+
323
+ # --- Set up diffusion process ---
324
+ self.scheduler.set_timesteps(num_inference_steps)
325
+
326
+ # Denoising loop
327
+ iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
328
+ for t in iterator:
329
+ # Handle conditional generation
330
+ if captions is not None:
331
+ if negatives is not None:
332
+ # Three copies for negative prompt guidance
333
+ model_input = torch.cat([sample, sample, sample], dim=0)
334
+ else:
335
+ # Two copies for standard classifier-free guidance
336
+ model_input = torch.cat([sample, sample], dim=0)
337
+ else:
338
+ model_input = sample
339
+
340
+ # Predict noise residual
341
+ model_kwargs = {"encoder_hidden_states": text_embeddings}
342
+ noise_pred = self.unet(model_input, t, **model_kwargs).sample
343
+
344
+ # Apply guidance
345
+ if captions is not None:
346
+ if negatives is not None:
347
+ # Split predictions for negative, unconditional, and text-conditional
348
+ noise_pred_neg, noise_pred_uncond, noise_pred_text = noise_pred.chunk(3)
349
+ noise_pred_guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
350
+ noise_pred = noise_pred_guided - guidance_scale * (noise_pred_neg - noise_pred_uncond)
351
+ else:
352
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
353
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
354
+
355
+ # Compute previous sample: x_{t-1} = scheduler(x_t, noise_pred)
356
+ sample = self.scheduler.step(noise_pred, t, sample, generator=generator).prev_sample
357
+
358
+ # Convert to output format
359
+ if output_type == "tensor":
360
+ if self.block_embeddings is not None:
361
+ sample = get_scene_from_embeddings(sample, self.block_embeddings)
362
+ else:
363
+ # Apply softmax to get probabilities for each tile type
364
+ sample = F.softmax(sample, dim=1)
365
+ sample = sample.detach().cpu()
366
+ else:
367
+ raise ValueError(f"Unsupported output type: {output_type}")
368
+
369
+ return PipelineOutput(images=sample)
370
+
371
+ def print_unet_architecture(self):
372
+ """Prints the architecture of the UNet model."""
373
+ print(self.unet)
374
+
375
+ def print_text_encoder_architecture(self):
376
+ """Prints the architecture of the text encoder model, if it exists."""
377
+ if self.text_encoder is not None:
378
+ print(self.text_encoder)
379
+ else:
380
+ print("No text encoder is set.")
381
+
382
+ def save_unet_architecture_pdf(self, height, width, filename="unet_architecture", batch_size=1, device=None):
383
+ """
384
+ Have to separately install torchview for this to work
385
+
386
+ Saves a visualization of the UNet architecture as a PDF using torchview.
387
+ Args:
388
+ height: Height of the dummy input.
389
+ width: Width of the dummy input.
390
+ filename: Output PDF filename.
391
+ batch_size: Batch size for dummy input.
392
+ device: Device to run the dummy input on (defaults to pipeline device).
393
+ """
394
+ from torchview import draw_graph
395
+ import graphviz
396
+
397
+ if device is None:
398
+ device = self.device if hasattr(self, 'device') else 'cpu'
399
+ in_channels = self.unet.config.in_channels if hasattr(self.unet, 'config') else 1
400
+ sample_shape = tuple([batch_size, in_channels, height, width])
401
+
402
+ dummy_x = torch.randn(size=sample_shape, device=device)
403
+ dummy_t = torch.tensor([0] * batch_size, dtype=torch.long, device=device)
404
+
405
+ # Prepare dummy text embedding (match what your UNet expects)
406
+ if hasattr(self.unet, 'config') and hasattr(self.unet.config, 'cross_attention_dim'):
407
+ cross_attention_dim = self.unet.config.cross_attention_dim
408
+ else:
409
+ cross_attention_dim = 128 # fallback
410
+ encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim, device=device)
411
+
412
+ self.unet.eval()
413
+ inputs = (dummy_x, dummy_t, encoder_hidden_states)
414
+ #self.unet.down_blocks = self.unet.down_blocks[:2]
415
+
416
+ graph = draw_graph(
417
+ model=self.unet,
418
+ input_data=inputs,
419
+ expand_nested=False,
420
+ #enable_output_shape=True,
421
+ #roll_out="nested",
422
+ depth=1
423
+ )
424
+ #graph.visual_graph.engine = "neato"
425
+ graph.visual_graph.attr(#rankdir="LR",
426
+ nodesep="0.1", # decrease space between nodes in the same rank (default ~0.25)
427
+ ranksep="0.2", # decrease space between ranks (default ~0.5)
428
+ concentrate="true" # merge edges between nodes in the same rank
429
+ )
430
+ graph.visual_graph.node_attr.update(
431
+ shape="rectangle",
432
+ width="1.5", # narrow width
433
+ height="0.5" # taller height to make vertical rectangles
434
+ #fixedsize="true"
435
+ )
436
+
437
+ graph.visual_graph.render(filename, format='pdf', cleanup=False) # Cleanup removes intermediate files
438
+ graph.visual_graph.save('unet_architecture.dot')
439
+
440
+ # Save the graph to a PDF file
441
+ print(f"UNet architecture saved to {filename}")
442
+
models/text_model.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from xml.parsers.expat import model
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+ import os
7
+ import json
8
+ from safetensors.torch import save_file, load_file
9
+ from tokenizer import Tokenizer
10
+
11
def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
    """
    Build the stacked text embeddings for guided sampling.

    The groups are concatenated along dim 0 in this order: embeddings of
    ``neg_captions`` (only when given), embeddings of ``batch_size`` empty
    captions, and embeddings of ``captions`` (only when given).

    Args:
        batch_size: Number of empty captions to embed for the unconditional group.
        tokenizer: Object with ``encode`` and ``pad_sequence``, as used by
            ``encode_token_captions``.
        text_encoder: Model exposing ``max_seq_length`` and ``get_embeddings(ids)``.
        captions: Optional list of caption strings.
        neg_captions: Optional list of negative caption strings.
        device: Device for the token ids and for the returned tensor.

    Returns:
        The concatenated embeddings, moved to ``device``.
    """
    seq_len = text_encoder.max_seq_length

    def embed(texts):
        token_ids = encode_token_captions(texts, tokenizer, seq_len, device=device)
        return text_encoder.get_embeddings(token_ids)

    # Same evaluation order as before: empty captions, captions, negatives.
    groups = [embed([""] * batch_size)]
    if captions is not None:
        groups.append(embed(captions))
    if neg_captions is not None:
        groups.insert(0, embed(neg_captions))

    return torch.cat(groups, dim=0).to(device)
27
+
28
def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
    """
    Tokenize and pad each caption, returning one long tensor of token ids.

    Args:
        captions: Iterable of caption strings.
        tokenizer: Object providing ``encode(caption)`` and
            ``pad_sequence(tokens, max_length)``.
        max_length: Length passed to ``tokenizer.pad_sequence``.
        device: Device the resulting tensor is moved to.

    Returns:
        A ``torch.long`` tensor with one row per caption, on ``device``.
    """
    rows = [
        torch.tensor(tokenizer.pad_sequence(tokenizer.encode(text), max_length), dtype=torch.long)
        for text in captions
    ]
    return torch.stack(rows, dim=0).to(device)
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+ # Transformer model for MLM training
45
+
46
class TransformerModel(nn.Module):
    """
    Transformer encoder over token ids with a linear vocabulary head.

    ``get_embeddings`` returns the encoder's hidden states, and ``forward``
    additionally projects them to vocabulary logits (used for MLM training).
    An optional tokenizer can be attached so that it is saved and restored
    together with the weights.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_seq_length = max_seq_length

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Registered as a NON-persistent buffer: it moves with .to()/.cuda()
        # like the parameters do, but stays out of state_dict, so checkpoints
        # written before this change still load with strict key matching.
        self.register_buffer(
            "positional_encoding",
            self.create_positional_encoding(max_seq_length, embedding_dim),
            persistent=False,
        )

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.tokenizer = tokenizer

    def create_positional_encoding(self, max_seq_length, embedding_dim):
        """Return a sinusoidal positional encoding of shape (1, max_seq_length, embedding_dim)."""
        # The frequencies create unique values, the sin/cos bounds values
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # Creates a set of divisors that create different frequencies
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pe = torch.zeros(max_seq_length, embedding_dim)
        # Even dimensions use sin, odd dimensions use cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)

    def get_embeddings(self, x):
        """Return the encoder's hidden states for token ids ``x`` (batch first)."""
        # Slice the encoding to the input's sequence length; being a buffer it
        # already lives on the module's device.
        pe = self.positional_encoding[:, :x.size(1), :]
        # Embed input and add positional encoding
        embedded = self.embedding(x) + pe
        return self.transformer(embedded)

    def forward(self, x):
        """Return vocabulary logits for every position of ``x``."""
        transformer_out = self.get_embeddings(x)
        # Project to vocabulary size
        return self.fc(transformer_out)

    def save_pretrained(self, save_directory):
        """Write config.json, model.safetensors and (if set) tokenizer.pkl into ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)

        config = {
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "max_seq_length": self.max_seq_length,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

        # Save model weights
        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))

        # Save tokenizer if present
        if self.tokenizer is not None:
            self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model from a directory written by ``save_pretrained``."""
        with open(os.path.join(load_directory, "config.json")) as f:
            config = json.load(f)

        model = cls(**config)

        # Load weights
        state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
        model.load_state_dict(state_dict)

        # Load tokenizer if available
        tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer()
            tokenizer.load(tokenizer_path)
            model.tokenizer = tokenizer

        return model

    def print_architecture(self, inputs=None):
        """
        Render this model's graph to ``mlm_architecture.pdf`` using torchview.

        The previous body parsed ``sys.argv`` and drew a different model loaded
        from ``--model_path``, ignoring ``self``; the method now draws the
        instance it is called on and needs no command-line arguments.

        Args:
            inputs: Example input handed to torchview's ``draw_graph`` as
                ``input_data``.
        """
        from torchview import draw_graph

        graph = draw_graph(
            model=self,
            input_data=inputs,
            expand_nested=False,
            depth=1
        )

        # Save plot
        filename = 'mlm_architecture'
        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup=False keeps intermediate files

    def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
        """
        Save a visualization of the model architecture as a PDF using torchview.

        Args:
            filename: Target PDF path. torchview first writes a PNG with the
                same stem, which is then converted and removed.
            input_length: Length of the dummy token sequence when no tokenizer
                is attached. With a tokenizer, the length is taken from two
                sample captions instead (the behaviour this method always had).

        Raises:
            ImportError: If torchview is not installed.
        """
        try:
            from torchview import draw_graph
        except ImportError as err:
            raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.") from err

        device = next(self.parameters()).device
        if self.tokenizer is not None:
            captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
            input_length = max(len(self.tokenizer.encode(c)) for c in captions)
        # The positional encoding only covers max_seq_length positions.
        input_length = min(input_length, self.max_seq_length)
        dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=device)

        # Draw the graph and save as PNG
        draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf', ''), directory=".", depth=2)
        png_file = filename.replace('.pdf', '.png')
        # Convert PNG to PDF
        if os.path.exists(png_file):
            try:
                from PIL import Image
                # Close the PNG before deleting it, and drop any alpha channel,
                # which PIL's PDF writer cannot store.
                with Image.open(png_file) as im:
                    im.convert("RGB").save(filename, "PDF", resolution=100.0)
                print(f"Saved architecture PDF to {filename}")
                os.remove(png_file)
            except ImportError:
                print(f"PIL not installed. Architecture saved as PNG: {png_file}")
            except Exception as e:
                print(f"Could not convert PNG to PDF: {e}")
        else:
            print(f"Could not find PNG file to convert: {png_file}")
util/common_settings.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Shared default settings. How each value is consumed is not visible in this
# module; the notes below are inferred from the names -- confirm against callers.

# Sampling defaults (presumably for the pipeline's num_inference_steps / guidance_scale).
NUM_INFERENCE_STEPS = 30
GUIDANCE_SCALE = 7.5

# Mario scene size (presumably in tiles -- TODO confirm).
MARIO_HEIGHT = 16
MARIO_WIDTH = 16

MARIO_TILE_PIXEL_DIM = 16  # presumably pixels per tile side
MARIO_TILE_COUNT = 13  # presumably number of distinct tile types

# LR_* values (presumably Lode Runner -- TODO confirm).
LR_HEIGHT = 32
LR_WIDTH = 32

LR_TILE_PIXEL_DIM = 8
LR_TILE_COUNT = 8

# Mega Man scene size.
MEGAMAN_HEIGHT = 14
MEGAMAN_WIDTH = 16
util/naming_conventions.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Ordered (raw model name, display name) pairs. The list order is the order
# returned by get_model_name_map_and_order().
model_name_map = [
    ("Mar1and2-conditional-regular", "MLM-regular"),
    ("Mar1and2-conditional-absence", "MLM-absence"),
    ("Mar1and2-conditional-negative", "MLM-negative"),
    ("Mar1and2-conditional-MiniLM-regular", "MiniLM-single-regular"),
    ("Mar1and2-conditional-MiniLM-absence", "MiniLM-single-absence"),
    ("Mar1and2-conditional-MiniLM-negative", "MiniLM-single-negative"),
    ("Mar1and2-conditional-MiniLMsplit-regular", "MiniLM-multiple-regular"),
    ("Mar1and2-conditional-MiniLMsplit-absence", "MiniLM-multiple-absence"),
    ("Mar1and2-conditional-MiniLMsplit-negative", "MiniLM-multiple-negative"),
    ("Mar1and2-conditional-GTE-regular", "GTE-single-regular"),
    ("Mar1and2-conditional-GTE-absence", "GTE-single-absence"),
    ("Mar1and2-conditional-GTE-negative", "GTE-single-negative"),
    ("Mar1and2-conditional-GTEsplit-regular", "GTE-multiple-regular"),
    ("Mar1and2-conditional-GTEsplit-absence", "GTE-multiple-absence"),
    ("Mar1and2-conditional-GTEsplit-negative", "GTE-multiple-negative"),
    ("Mar1and2-fdm-MiniLM-regular", "FDM-MiniLM-regular"),
    ("Mar1and2-fdm-MiniLM-absence", "FDM-MiniLM-absence"),
    ("Mar1and2-fdm-GTE-regular", "FDM-GTE-regular"),
    ("Mar1and2-fdm-GTE-absence", "FDM-GTE-absence"),
    ("Mar1and2-wgan", "WGAN"),
    ("Mar1and2-unconditional", "Unconditional"),
    ("MarioGPT_metrics", "MarioGPT"),
]


def get_model_name_map_and_order():
    """
    Return the name map in two forms.

    Returns:
        A ``(mapping, order)`` tuple: ``mapping`` is a dict from raw name to
        display name, and ``order`` is the list of display names in the order
        they appear in ``model_name_map``.
    """
    display_names = []
    for _raw_name, display_name in model_name_map:
        display_names.append(display_name)
    return dict(model_name_map), display_names
util/plotter.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Track changes in loss and learning rate during execution
2
+ import argparse
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import os
6
+ import time
7
+ import json
8
+ import tempfile
9
+ import shutil
10
+ from pathlib import Path
11
+
12
+
13
def parse_args():
    """
    Parse the command-line options for a one-shot plot of a training log.

    Returns:
        The ``argparse.Namespace`` with ``log_file``, ``left_key``,
        ``right_key``, ``left_label``, ``right_label``, ``output_png``,
        ``update_interval`` and ``start_point``.
    """
    parser = argparse.ArgumentParser(description="Plot metrics from a JSON-lines training log against the epoch")

    # The log file is mandatory: plotting calls os.path.dirname() on it, which
    # raises a TypeError for the old default of None.
    parser.add_argument("--log_file", type=str, required=True, help="The filepath of the file to get the data from")
    parser.add_argument("--left_key", type=str, default=None, help="The key for the left y-axis")
    parser.add_argument("--right_key", type=str, default=None, help="The key for the right y-axis")
    parser.add_argument("--left_label", type=str, default=None, help="The label for the left y-axis")
    parser.add_argument("--right_label", type=str, default=None, help="The label for the right y-axis")
    parser.add_argument("--output_png", type=str, default="output.png", help="The output png file")
    # float, matching the 1.0 default; type=int used to reject values like 0.5.
    parser.add_argument("--update_interval", type=float, default=1.0, help="The update interval")
    parser.add_argument("--start_point", type=int, default=None, help="The start point for the plot")

    return parser.parse_args()
27
+
28
+
29
def main():
    """Command-line entry point: parse the options and draw the plot once."""
    opts = parse_args()
    general_update_plot(
        opts.log_file,
        opts.left_key,
        opts.right_key,
        opts.left_label,
        opts.right_label,
        opts.output_png,
        update_interval=opts.update_interval,
        startPoint=opts.start_point,
    )
42
+
43
+
44
def general_update_plot(log_file, left_key, right_key, left_label, right_label, output_png, update_interval=1.0, startPoint=None):
    """
    Plot two metrics from a JSON-lines log against the epoch and save a PNG.

    Every non-blank line of ``log_file`` is parsed as one JSON object. Nothing
    is written when the file is missing, empty, or has no entry at or after
    ``startPoint``.

    Args:
        log_file: Path of the JSON-lines log.
        left_key: Key plotted in blue; entries lacking it contribute 0.
        right_key: Key plotted in red; only entries containing it are used.
        left_label: Legend label of the first curve and the y-axis label.
        right_label: Legend label of the second curve.
        output_png: Output file. A bare file name is placed next to the log;
            a name with a directory part is used as given.
        update_interval: Accepted but not used by this function.
        startPoint: If given, entries with a smaller 'epoch' are dropped and
            the x-axis starts there.
    """
    log_dir = os.path.dirname(log_file)

    # The figure is created up front; the finally clause guarantees it is closed.
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    try:
        if not os.path.exists(log_file):
            return

        with open(log_file, 'r') as f:
            data = [json.loads(line) for line in f if line.strip()]
        if not data:
            return

        if startPoint is not None:
            data = [entry for entry in data if entry.get('epoch', 0) >= startPoint]
        if not data:
            return

        epochs = []
        left = []
        for entry in data:
            epochs.append(entry.get('epoch', 0))
            left.append(entry.get(left_key, 0))

        # Second curve (e.g., lr): only the points whose entry has right_key
        with_right = [entry for entry in data if right_key in entry]
        right_epochs = [entry.get('epoch', 0) for entry in with_right]
        right_values = [entry.get(right_key) for entry in with_right]

        ax.clear()

        # Both metrics share one axis
        ax.plot(epochs, left, 'b-', label=left_label)
        if right_epochs:
            ax.plot(right_epochs, right_values, 'r-', label=right_label)

        ax.set_xlabel('Epoch')
        ax.set_ylabel(left_label)
        ax.set_title('Training Progress')
        ax.legend(loc='upper left')
        if startPoint is not None:
            ax.set_xlim(left=startPoint)
        fig.tight_layout()

        # A bare file name goes next to the log file; anything carrying a
        # directory component is taken as-is.
        has_directory = os.path.isabs(output_png) or os.path.dirname(output_png)
        output_path = output_png if has_directory else os.path.join(log_dir, output_png)

        save_figure_safely(fig, output_path)
    finally:
        plt.close(fig)  # Ensure figure is closed even if an error occurs
102
+
103
def save_figure_safely(fig, output_path):
    """
    Save ``fig`` to a temporary file, then atomically move it into place.

    The temporary file is created inside the destination directory so that the
    final step is a same-filesystem ``os.replace``. A reader polling
    ``output_path`` therefore sees either the old image or the complete new
    one. The temporary file keeps a ``.png`` suffix, so the figure is always
    written as PNG.

    Args:
        fig: Object with a ``savefig(path)`` method (e.g. a matplotlib Figure).
        output_path: Destination path; missing parent directories are created.

    Raises:
        Whatever ``fig.savefig`` or the filesystem calls raise. The temporary
        file is removed before the exception propagates.
    """
    output_path = str(Path(output_path))  # Convert to string path

    # Create output directory if it doesn't exist
    target_dir = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(target_dir, exist_ok=True)

    # Reserve a unique temporary name next to the destination
    with tempfile.NamedTemporaryFile(suffix='.png', dir=target_dir, delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        fig.savefig(tmp_path)
        os.replace(tmp_path, output_path)
    finally:
        # After a successful replace the temporary name no longer exists, so
        # this only fires when saving or moving failed (including interrupts).
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
130
+
131
class Plotter:
    """
    Repeatedly redraw a training-progress plot from a log file.

    ``start_plotting`` loops until ``stop_plotting`` is called, redrawing
    through ``general_update_plot`` and sleeping ``update_interval`` seconds
    between redraws. The object is also a context manager that stops plotting
    on exit.
    """

    def __init__(self, log_file, update_interval=1.0, left_key='loss', right_key='lr',
                 left_label='Loss', right_label='Learning Rate', output_png='training_progress.png'):
        self.log_dir = os.path.dirname(log_file)
        self.log_file = log_file
        self.update_interval = update_interval
        self.running = True
        self.output_png = output_png
        self.left_key = left_key
        self.right_key = right_key
        self.left_label = left_label
        self.right_label = right_label

        matplotlib.use('Agg')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_plotting()

    def __del__(self):
        self.stop_plotting()

    def update_plot(self):
        """Redraw the plot once from the current contents of the log file."""
        general_update_plot(self.log_file, self.left_key, self.right_key,
                            self.left_label, self.right_label, self.output_png,
                            update_interval=self.update_interval)

    def start_plotting(self):
        """Redraw in a loop until ``running`` is cleared; blocks the calling thread."""
        print("Starting plotting in background")
        while self.running:
            self.update_plot()
            time.sleep(self.update_interval)

    def stop_plotting(self):
        """
        Stop the loop and draw one final plot; later calls are no-ops.

        The previous ``hasattr(self, 'running')`` test was true for every
        initialized instance, so ``__exit__`` followed by ``__del__`` redrew
        the plot and printed the message twice. Testing the flag's value makes
        the call idempotent while still covering an instance whose
        ``__init__`` failed before ``running`` was set.
        """
        if getattr(self, 'running', False):
            self.running = False
            self.update_plot()
            print("Plotting stopped")
171
+
172
+ if __name__ == "__main__":
173
+ main()
util/sampler.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import subprocess
8
+ import tempfile
9
+
10
+ import numpy as np
11
+ import torch
12
+ from PIL.Image import Image
13
+ from tqdm import tqdm
14
+ from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
15
+
16
+
17
+ from mario_gpt.lm.base import BaseMarioLM
18
+ from mario_gpt.prompter import Prompter
19
+ from mario_gpt.simulator import Simulator
20
+ from mario_gpt.utils import (
21
+ convert_level_to_png,
22
+ load_level,
23
+ save_level,
24
+ trim_level,
25
+ view_level,
26
+ )
27
+
28
def scene_to_ascii(scene, id_to_char, shorten: bool = True) -> List[str]:
    """
    Convert a scene given as rows of tile IDs into ASCII strings.

    Args:
        scene: List[List[int]] - 2D array of tile IDs
        id_to_char: Dict[int, str] - mapping from tile ID to ASCII character
        shorten: bool - If True and the scene has more than 15 rows, only the
                 LAST 15 rows are kept (so A* Mario with SNES graphics can
                 run without glitching)
    Returns:
        List[str]: one string per kept row
    """
    rows = scene[-15:] if (shorten and len(scene) > 15) else scene
    ascii_rows = []
    for row in rows:
        chars = [id_to_char[tile_id] for tile_id in row]
        ascii_rows.append("".join(chars))
    return ascii_rows
44
+
45
@dataclass
class SampleOutput:
    """One sampled level in its string, image and tensor forms, plus helpers to save and play it."""

    # Level rows as strings; None when the string rendering failed.
    level: Optional[List[str]]
    prompt: Optional[str] = None
    img: Optional[Image] = None
    sample_predictions_str: Optional[List[str]] = None
    sample_predictions_img: Optional[Image] = None
    level_tensor: Optional[torch.Tensor] = None
    sample_predictions_tensor: Optional[torch.Tensor] = None
    # When True the simulator runs MarioEval.jar, otherwise NESMarioEval.jar.
    use_snes_graphics: bool = False

    @classmethod
    def create(
        cls,
        level_tensor: torch.Tensor,
        sample_predictions_tensor: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> SampleOutput:
        """
        Build a SampleOutput for a single (unbatched) level.

        Rendering failures are reported on stdout and leave the affected
        string/image fields as None instead of raising.
        """
        # batch = 1
        level = None
        img = None

        try:
            level = view_level(level_tensor, tokenizer)
            img = convert_level_to_png(level)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for full level! Got error {e}"
            )
            level = None
            img = None
        try:
            sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
            sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for sampled predictions! Got error {e}"
            )
            sample_predictions_str = None
            sample_predictions_img = None

        prompt = None
        if prompter is not None:
            prompt = prompter(level_tensor)[0]

        return SampleOutput(
            level,
            prompt,
            img,
            sample_predictions_str,
            sample_predictions_img,
            level_tensor,
            sample_predictions_tensor,
        )

    @classmethod
    def from_level_predictions(
        cls,
        level: torch.Tensor,
        sample_predictions: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> Union[SampleOutput, List[SampleOutput]]:
        """
        Trim and squeeze the tensors, then wrap them.

        Returns a single SampleOutput when the trimmed level is 1-D, otherwise
        a list with one SampleOutput per leading-dimension entry.
        """
        level_tensor = trim_level(level).squeeze().detach().cpu()
        sample_predictions_tensor = (
            trim_level(sample_predictions).squeeze().detach().cpu()
        )

        if len(level_tensor.shape) == 1:
            return SampleOutput.create(
                level_tensor, sample_predictions_tensor, tokenizer, prompter
            )

        out = []
        for _level_tensor, _sample_predictions_tensor in zip(
            level_tensor, sample_predictions_tensor
        ):
            sample_output = SampleOutput.create(
                _level_tensor, _sample_predictions_tensor, tokenizer, prompter
            )
            out.append(sample_output)
        return out

    def save(self, filename: str) -> str:
        """Write the level rows to ``filename`` and return what ``save_level`` returns."""
        # The annotation promised a str, but the result used to be discarded.
        return save_level(self.level, filename)

    @classmethod
    def load(cls, filename: str) -> SampleOutput:
        """Load a level with ``load_level`` and wrap it; all other fields keep their defaults."""
        level = load_level(filename)
        return SampleOutput(level=level)

    def _make_simulator(self):
        """Return a CustomSimulator for this level using the jar selected by ``use_snes_graphics``."""
        jar_path = "MarioEval.jar" if self.use_snes_graphics else "NESMarioEval.jar"
        return CustomSimulator(level=self.level, jar_path=jar_path)

    def play(self, game="mario", level_idx=None, dataset_path=None):
        """
        Play the level using the specified game engine.

        Args:
            game: "mario" (default) or "loderunner".
            level_idx: Lode Runner only; level index handed to the engine
                (1 when omitted).
            dataset_path: Lode Runner only; an existing dataset file to play
                instead of this object's level.
        """
        if game == "loderunner":
            import json
            from LodeRunner.loderunner import main

            if dataset_path is None:
                # Convert self.level (list of strings) to Lode Runner JSON format.
                # Only done when it is needed: previously the temp file was
                # written (and self.level read) even when dataset_path
                # overrode it.
                lr_json = [{
                    "scene": [list(row) for row in self.level],
                    "caption": ""
                }]
                with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
                    json.dump(lr_json, tmp)
                    level_path = tmp.name
                # NOTE(review): the temp file is left on disk because it is not
                # visible here whether play_lr_level blocks until the game ends.
            else:
                level_path = dataset_path
            print(f"Playing Lode Runner level interactively -- {level_path}!")
            main.play_lr_level(level_path, level_index=level_idx if level_idx is not None else 1)
        else:
            self._make_simulator().interactive()

    def run_astar(self, render=True):
        """Run the A* agent on the level and return the simulator's combined console output."""
        return self._make_simulator().astar(render)
173
+
174
class CustomSimulator:
    """
    Runs a level through an external Mario jar via ``java -jar``.

    The classic Mario simulator used by MarioGPT is generally better, but it
    doesn't return any information about Mario's performance. The main point
    of this simulator is that information about the performance of the agent
    is printed by the jar and, for ``astar``, captured and returned.
    """

    def __init__(self, level, jar_path="MarioEval.jar"):
        # Keep only the last 15 rows: the older A* agent crashes on Mario
        # levels with 16 rows or more. The slice is a copy, so the caller's
        # list is no longer shortened in place (the old code popped from it).
        self.level = list(level[-15:])
        self.jar_path = jar_path

    def _run_jar(self, agent, render_str, banner):
        """
        Write the level to a temp file, run the jar on it, and always clean up.

        Args:
            agent: First jar argument ("human" or "astar").
            render_str: Third jar argument ("human" or "norender").
            banner: Message printed before launch; ``{path}`` is replaced by
                the temp file path.

        Returns:
            The ``subprocess.CompletedProcess`` with captured stdout/stderr.
        """
        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as t:
            level_path = t.name
        try:
            save_level(self.level, level_path)
            print(banner.format(path=level_path))
            return subprocess.run(
                ["java", "-jar", self.jar_path, agent, level_path, render_str],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        finally:
            # Remove the temp file even if writing it or launching java failed.
            os.unlink(level_path)

    def interactive(self):
        """Let a human play the level; the jar's output is discarded."""
        self._run_jar("human", "human", "Playing level interactively -- {path}!")

    def astar(self, render: bool = True):
        """
        Run the A* agent on the level.

        Args:
            render: When False the jar is started with "norender".

        Returns:
            The jar's stdout followed by its stderr, decoded as UTF-8.
        """
        render_str = "human" if render else "norender"
        result = self._run_jar("astar", render_str, "Running Astar agent on level! -- {path}")
        # errors="replace": stray non-UTF-8 bytes in the JVM's output should
        # not abort an otherwise finished run.
        stdout_text = result.stdout.decode("utf-8", errors="replace")
        stderr_text = result.stderr.decode("utf-8", errors="replace")
        return stdout_text + stderr_text
220
+
221
def save_level(level: List[str], filename: str):
    """
    Write the level rows to ``filename``, one row per line.

    Rows are joined with newlines and no trailing newline is written.

    Returns:
        The ``filename`` that was written.
    """
    with open(filename, "w") as handle:
        handle.write("\n".join(level))
    return filename
226
+
227
class GPTSampler:
    """Autoregressive sampler that extends a level one token at a time.

    Each step feeds the current token sequence (and the prompt's encoder
    hidden states) to ``mario_lm.lm`` and picks the next token either
    greedily or by top-k / temperature sampling.
    """

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 700,
        use_tqdm: bool = False,
        use_argmax: bool = False,
    ):
        """Store the sampling configuration.

        Args:
            mario_lm: Wrapper providing ``lm``, ``prompter``, ``tokenizer``,
                ``device``, ``generate_seed``, ``eval`` and ``train``.
            temperature: Temperature handed to ``TemperatureLogitsWarper``.
            top_k: ``k`` handed to ``TopKLogitsWarper``.
            context_len: Upper bound used when cropping the model input
                (see ``sample``).
            use_tqdm: Show a progress bar while sampling.
            use_argmax: Pick the highest-scoring token instead of sampling.
        """
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.context_len = context_len
        self.use_tqdm = use_tqdm
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )

    @property
    def device(self) -> torch.device:
        """Device of the wrapped language model."""
        return self.mario_lm.device

    def step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict one next token for every sequence in ``seed``.

        Args:
            seed: Token ids fed to the model as ``input_ids``.
            encoder_hidden_states: Conditioning passed through to the model.

        Returns:
            ``(next_tokens, encoder_hidden_states)``; the hidden states are
            returned unchanged.
        """
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if len(logits.shape) == 2:
                logits = logits.view(1, 1, -1)
            # Only the prediction at the last position is needed.
            next_token_logits = logits[:, -1, :]

            if self.use_argmax:
                next_tokens = next_token_logits.argmax(-1)
            else:
                next_token_scores = self.logits_processor(input_ids, next_token_logits)
                next_token_scores = self.logits_warper(input_ids, next_token_scores)
                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        encoder_hidden_states: torch.Tensor = None,
        return_tensor: bool = False,
    ):
        """Generate ``num_steps`` new tokens after ``seed``.

        Args:
            seed: Starting tokens as a tensor or a ``SampleOutput``. When
                ``None`` a seed is obtained from ``mario_lm.generate_seed``.
            prompts: Text prompts, one per generated sequence. When ``None``
                and no ``encoder_hidden_states`` are given, prompts are
                sampled from ``mario_lm.prompter``.
            num_steps: Number of tokens to append.
            encoder_hidden_states: Precomputed conditioning; computed from
                the prompts when ``None``.
            return_tensor: Also return the raw token tensor.

        Returns:
            A ``SampleOutput``, or ``(SampleOutput, out_tensor)`` when
            ``return_tensor`` is true.
        """
        self.mario_lm.eval()
        # NOTE(review): 28 looks like two level columns (2 * 14) kept as
        # headroom -- confirm against the model's maximum input length.
        context_len = self.context_len - 28

        # Number of sequences to produce when the seed itself does not say.
        # Previously ``len(prompts)`` was used unconditionally and raised a
        # TypeError whenever ``prompts`` was None.
        if prompts is not None:
            batch_hint = len(prompts)
        elif encoder_hidden_states is not None and encoder_hidden_states.dim() > 1:
            batch_hint = encoder_hidden_states.shape[0]
        else:
            batch_hint = 1

        with torch.no_grad():
            if seed is None:
                out_tensor = self.mario_lm.generate_seed(
                    1, batch_size=batch_hint
                ).to(self.device)
            elif isinstance(seed, SampleOutput):
                out_tensor = seed.level_tensor.to(self.device).squeeze()
            else:
                out_tensor = seed.to(self.device).squeeze()
            if len(out_tensor.shape) < 2:
                # if we pass in a single seed vector, then we repeat for each prompt
                # Otherwise, we treat inputs as separate seed-prompt pairs
                out_tensor = out_tensor.view(1, -1).repeat(batch_hint, 1)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter.output_hidden(prompt)
                            for prompt in prompts
                        ]
                    )
                else:
                    # One sampled prompt per sequence. Use the normalised
                    # tensor: ``seed`` may be None or a SampleOutput here.
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter(sample_prompt=True)[1]
                            for _ in range(out_tensor.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(
                out_tensor.shape[0], 1, -1
            )
        if not self.use_tqdm:
            bar = np.arange(num_steps)
        else:
            bar = tqdm(np.arange(num_steps))
        with torch.no_grad():
            for i in bar:
                inp = out_tensor * 1
                if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
                    # Crop to the most recent tokens while keeping the
                    # current partial column in the window.
                    diff = inp.shape[-1] % 14  # height of mario level
                    ctx = context_len + diff
                    inp = inp[:, -ctx:] * 1
                next_tokens, encoder_hidden_states = self.step(
                    inp,
                    encoder_hidden_states=encoder_hidden_states,
                )
                out_tensor = torch.cat(
                    [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
                )
                if self.use_tqdm:
                    bar.set_description(
                        f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
                    )
        if self.use_tqdm:
            bar.close()
        # Explicit start index: ``-num_steps:`` would select the whole
        # sequence when num_steps == 0.
        new_tokens = out_tensor[:, out_tensor.shape[-1] - num_steps :]
        sample_out = SampleOutput.from_level_predictions(
            out_tensor,
            new_tokens,
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        self.mario_lm.train()
        if return_tensor:
            return sample_out, out_tensor
        return sample_out

    def __call__(self, *args, **kwargs):
        """Alias for ``sample``."""
        return self.sample(*args, **kwargs)
365
+
366
+
367
class BertSampler:
    """Fill masked level tokens with a masked language model.

    For every sequence a fixed-length context window around the masked
    positions is cut out, run through ``mario_lm.lm``, and the masked
    positions inside that window are replaced by predicted tokens.
    """

    # Tokens per level column; windows are aligned to this where possible.
    _LEVEL_HEIGHT = 14

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 448,
        mask_proportion: float = 0.16,
        use_argmax: bool = False,
    ):
        """Store the sampling configuration.

        Args:
            mario_lm: Wrapper providing ``lm``, ``prompter``, ``tokenizer``,
                ``device``, ``eval`` and ``train``.
            temperature: Temperature handed to ``TemperatureLogitsWarper``.
            top_k: ``k`` handed to ``TopKLogitsWarper``.
            context_len: Maximum number of tokens fed to the model.
            mask_proportion: Fraction of ``context_len`` used to derive
                ``mask_portion``.
            use_argmax: Pick the highest-scoring token instead of sampling.
                ``sample`` reads this flag, but it was never initialised
                before, so it is added here with a default.
        """
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )
        self.context_len = context_len
        self.mask_proportion = mask_proportion
        self.mask_portion = int(self.context_len * self.mask_proportion)
        # Move up to the next multiple of 14 (a full step is added even when
        # the value is already a multiple).
        self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14

    @property
    def device(self) -> torch.device:
        """Device of the wrapped language model."""
        return self.mario_lm.device

    def _context_bounds(self, seq_len: int, mask_indices) -> Tuple[int, int]:
        """Return ``(left, right)`` of the context window for one sequence.

        The window has length ``min(context_len, seq_len)``, is centred on
        the span between the first and last masked index, is shifted to stay
        inside ``[0, seq_len]``, and its start is snapped down to a column
        boundary.
        """
        window = min(self.context_len, seq_len)
        if len(mask_indices) == 0:
            return 0, window
        first = int(mask_indices[0])
        last = int(mask_indices[-1])
        span = last - first + 1
        # Integer arithmetic only: slice bounds must be ints.
        left = first - max(window - span, 0) // 2
        left = max(0, min(left, seq_len - window))
        # Snapping down keeps left >= 0 and left + window <= seq_len.
        left -= left % self._LEVEL_HEIGHT
        return left, left + window

    def get_context(self, input_ids, mask_indices):
        """Cut the context window around ``mask_indices`` out of ``input_ids``.

        Args:
            input_ids: Token ids of one sequence (indexed along the last
                dimension).
            mask_indices: Ascending positions of the masked tokens.

        Returns:
            The slice of ``input_ids`` covered by the context window.
        """
        left, right = self._context_bounds(input_ids.shape[-1], mask_indices)
        return input_ids[..., left:right]

    def sample(
        self,
        seed: Union[torch.Tensor, SampleOutput],
        mask: torch.Tensor,
        return_tensor: bool = False,
    ):
        """Predict tokens for the masked positions of ``seed``.

        Args:
            seed: Token ids as a tensor or a ``SampleOutput``.
            mask: Tensor with the same number of elements as the token ids;
                non-zero entries mark positions to re-predict.
            return_tensor: Also return the predicted token tensor.

        Returns:
            A ``SampleOutput`` built from the context windows with their
            masked positions replaced, or ``(SampleOutput, tokens)`` when
            ``return_tensor`` is true. Note that the result covers the
            context window, not necessarily the full input sequence.
        """
        self.mario_lm.eval()
        input_ids = seed
        if isinstance(seed, SampleOutput):
            input_ids = seed.level_tensor.squeeze()
        input_ids = input_ids.to(self.device)
        if len(input_ids.shape) < 2:
            # A single sequence is treated as a batch of one.
            input_ids = input_ids.view(1, -1)
        mask = mask.to(self.device).reshape(input_ids.shape)
        mask_indices = mask.nonzero()

        seq_len = input_ids.shape[-1]
        windows = []
        window_columns = []
        for i in range(input_ids.shape[0]):
            columns = mask_indices[mask_indices[:, 0] == i][:, -1]
            left, right = self._context_bounds(seq_len, columns)
            windows.append(input_ids[i, left:right])
            # Translate mask positions into window coordinates and drop any
            # that fall outside the window.
            local = columns - left
            window_columns.append(local[(local >= 0) & (local < right - left)])
        input_ids = torch.stack(windows, dim=0)

        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=None,
            )
        logits = out.logits.detach()
        if len(logits.shape) == 2:
            logits = logits.unsqueeze(0)

        if self.use_argmax:
            tokens = logits.argmax(-1)
        else:
            # The processors and torch.multinomial work on 2-D inputs, so
            # every position is treated as its own row and reshaped back.
            flat_scores = logits.reshape(-1, logits.shape[-1])
            flat_ids = input_ids.reshape(-1, 1)
            flat_scores = self.logits_processor(flat_ids, flat_scores)
            flat_scores = self.logits_warper(flat_ids, flat_scores)
            probs = torch.nn.functional.softmax(flat_scores, dim=-1)
            tokens = torch.multinomial(probs, num_samples=1).view(logits.shape[:-1])

        out = input_ids.clone()

        for i, columns in enumerate(window_columns):
            out[i, columns] = tokens[i, columns]

        sample_out = SampleOutput.from_level_predictions(
            out,
            tokens,
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        self.mario_lm.train()
        if return_tensor:
            return sample_out, tokens
        return sample_out