Code that is apparently needed by the pipeline
Browse files. Probably more than is needed, but we'll see.
Some files had to be modified from their original repo versions
- captions/caption_match.py +247 -0
- captions/evaluate_caption_order_tolerance.py +288 -0
- captions/util.py +216 -0
- mapsheet.png +0 -0
- models/general_training_helper.py +172 -0
- models/latent_diffusion_pipeline.py +99 -0
- models/pipeline_loader.py +41 -0
- models/sentence_transformers_helper.py +114 -0
- models/text_diffusion_pipeline.py +442 -0
- models/text_model.py +206 -0
- util/common_settings.py +18 -0
- util/naming_conventions.py +29 -0
- util/plotter.py +173 -0
- util/sampler.py +473 -0
captions/caption_match.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from create_ascii_captions import assign_caption

# Quantity order for scoring partial matches; the index distance between two
# terms measures how far apart two quantity descriptions are.
QUANTITY_TERMS = ["one", "two", "a few", "several", "many"]

# Topics to compare. Order matters: longer, more specific keywords appear
# before their substrings (e.g. "broken pipe" before "pipe") so a phrase is
# claimed by the most specific topic first in extract_phrases.
TOPIC_KEYWORDS = [
    #"giant gap", # I think all gaps are subsumed by the floor topic
    "floor", "ceiling",
    "broken pipe", "upside down pipe", "pipe",
    "coin line", "coin",
    "platform", "tower", #"wall",
    "broken cannon", "cannon",
    "ascending staircase", "descending staircase",
    "rectangular",
    "irregular",
    "question block", "loose block",
    "enem" # catch "enemy"/"enemies"
]

# Need list because the order matters: the "broken ..." removals must run
# before the generic "pipe"/"cannon" replacements further down.
KEYWORD_TO_NEGATED_PLURAL = [
    (" broken pipe.", ""), # If not the first phrase
    ("broken pipe. ", ""), # If the first phrase (after removing all others)
    (" broken cannon.", ""), # If not the first phrase
    ("broken cannon. ", ""), # If the first phrase (after removing all others)
    ("pipe", "pipes"),
    ("cannon", "cannons"),
    ("platform", "platforms"),
    ("tower", "towers"),
    ("staircase", "staircases"),
    ("enem", "enemies"),
    ("rectangular", "rectangular block clusters"),
    ("irregular", "irregular block clusters"),
    ("coin line", "coin lines"),
    ("coin.", "coins."), # Need period to avoid matching "coin line"
    ("question block", "question blocks"),
    ("loose block", "loose blocks")
]

BROKEN_TOPICS = 2 # Number of topics that are considered "broken" (e.g., "broken pipe", "broken cannon")

# Plural normalization map (irregulars that the simple strip-trailing-"s"
# rule in normalize_plural would get wrong).
PLURAL_EXCEPTIONS = {
    "enemies": "enemy",
}
|
| 47 |
+
|
| 48 |
+
def normalize_plural(phrase):
    """Return *phrase* with plural nouns reduced to singular form.

    Irregular plurals listed in PLURAL_EXCEPTIONS are substituted first;
    afterwards any word ending in a single 's' (but not 'ss', to spare
    words like "class" or "boss") loses that trailing 's'.
    """
    for plural_form, singular_form in PLURAL_EXCEPTIONS.items():
        phrase = phrase.replace(plural_form, singular_form)

    def _singularize(token):
        # Basic plural: drop one trailing 's'; leave double-'s' words alone.
        if token.endswith('s') and not token.endswith('ss'):
            return token[:-1]
        return token

    return ' '.join(_singularize(token) for token in phrase.split())
|
| 63 |
+
|
| 64 |
+
def extract_phrases(caption, debug=False):
    """Map each topic in TOPIC_KEYWORDS to the first caption phrase mentioning it.

    The caption is split on periods. Topics are scanned in TOPIC_KEYWORDS
    order, so longer/more specific keywords claim a phrase before their
    substrings can. A phrase beginning with "no " is treated the same as an
    absent topic (mapped to None). Returns a dict topic -> phrase or None.
    """
    phrases = [part.strip() for part in caption.split('.') if part.strip()]
    claimed = set()  # phrases already assigned to an earlier (longer) topic
    topic_to_phrase = {}

    for topic in TOPIC_KEYWORDS:
        candidates = [p for p in phrases if topic in p and p not in claimed]

        if not candidates:
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': no phrase found")
            continue

        phrase = candidates[0]
        if phrase.lower().startswith("no "):
            # "no ..." phrases are equivalent to the topic being absent.
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
        else:
            topic_to_phrase[topic] = phrase
            claimed.add(phrase)  # Mark this phrase as matched
            if debug:
                print(f"[Extract] Topic '{topic}': found phrase '{phrase}'")

    return topic_to_phrase
|
| 95 |
+
|
| 96 |
+
def quantity_score(phrase1, phrase2, debug=False):
    """Score agreement between the quantity terms of two phrases.

    Each phrase is scanned for the first matching term in QUANTITY_TERMS;
    the score is 1.0 minus the normalized index distance between the two
    terms. If either phrase has no quantity term, a flat 0.1 is returned.
    """
    def _first_quantity(text):
        return next((term for term in QUANTITY_TERMS if term in text), None)

    qty1 = _first_quantity(phrase1)
    qty2 = _first_quantity(phrase2)

    if debug:
        print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")

    if qty1 and qty2:
        idx1 = QUANTITY_TERMS.index(qty1)
        idx2 = QUANTITY_TERMS.index(qty2)
        diff = abs(idx1 - idx2)
        max_diff = len(QUANTITY_TERMS) - 1
        score = 1.0 - (diff / max_diff)
        if debug:
            print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
        return score

    if debug:
        print("[Quantity] At least one quantity missing, assigning partial score 0.1")
    return 0.1
|
| 121 |
+
|
| 122 |
+
def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
    """Score a generated caption against a reference caption, topic by topic.

    Per topic in TOPIC_KEYWORDS:
      * absent on both sides              -> +1.0
      * present on exactly one side       -> -1.0 (generated-only = "excess")
      * identical after plural-normalize  -> +1.0 (exact match)
      * both carry a quantity term        -> quantity_score partial credit
      * otherwise (topic overlap only)    -> +0.1
    The sum is divided by the number of topics. With return_matches=True the
    exact-match, partial-match, and excess phrase lists are also returned.
    """
    reference = extract_phrases(correct_caption, debug=debug)
    candidate = extract_phrases(generated_caption, debug=debug)

    score_sum = 0.0
    topic_count = len(TOPIC_KEYWORDS)
    exact_matches, partial_matches, excess_phrases = [], [], []

    if debug:
        print("\n--- Starting Topic Comparison ---\n")

    for topic in TOPIC_KEYWORDS:
        ref_phrase = reference[topic]
        gen_phrase = candidate[topic]

        if debug:
            print(f"[Topic: {topic}] Correct: {ref_phrase} | Generated: {gen_phrase}")

        if ref_phrase is None and gen_phrase is None:
            score_sum += 1.0
            if debug:
                print(f"[Topic: {topic}] Both None — full score: 1.0\n")
        elif ref_phrase is None or gen_phrase is None:
            score_sum -= 1.0
            if gen_phrase is not None:  # Considered an excess phrase
                excess_phrases.append(gen_phrase)
            if debug:
                print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
        else:
            # Normalize pluralization before comparison
            norm_ref = normalize_plural(ref_phrase)
            norm_gen = normalize_plural(gen_phrase)

            if debug:
                print(f"[Topic: {topic}] Normalized: Correct: '{norm_ref}' | Generated: '{norm_gen}'")

            if norm_ref == norm_gen:
                score_sum += 1.0
                exact_matches.append(gen_phrase)
                if debug:
                    print(f"[Topic: {topic}] Exact match — score: 1.0\n")
            elif any(t in norm_ref for t in QUANTITY_TERMS) and any(t in norm_gen for t in QUANTITY_TERMS):
                qty_score = quantity_score(norm_ref, norm_gen, debug=debug)
                score_sum += qty_score
                partial_matches.append(gen_phrase)
                if debug:
                    print(f"[Topic: {topic}] Quantity-based partial score: {qty_score:.2f}\n")
            else:
                score_sum += 0.1
                partial_matches.append(gen_phrase)
                if debug:
                    print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")

        if debug:
            print(f"[Topic: {topic}] Current total score: {score_sum:.4f}\n")

    if debug:
        print("total_score before normalization:", score_sum)
        print(f"Number of topics: {topic_count}")

    final_score = score_sum / topic_count
    if debug:
        print(f"--- Final score: {final_score:.4f} ---\n")

    if return_matches:
        return final_score, exact_matches, partial_matches, excess_phrases

    return final_score
|
| 193 |
+
|
| 194 |
+
def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
    """Partition a scene into fixed-width segments, caption each, and score them.

    Args:
        scene (list): 2D list of tile IDs (rows of columns).
        segment_width (int): Width, in tiles, of each vertical slice.
        prompt (str): Reference caption each segment caption is scored against.
        id_to_char (dict): Mapping from tile IDs to characters.
        char_to_id (dict): Mapping from characters to tile IDs.
        tile_descriptors (dict): Descriptions of individual tile types.
        describe_locations (bool): Include location descriptions in captions.
        describe_absence (bool): Indicate absence of items in captions.
        verbose (bool): If True, print each segment's caption and score.

    Returns:
        tuple: (average score, list of segment captions, list of segment scores).
    """
    # Slice every row at the same column offsets to form vertical segments.
    segments = [
        [row[start:start + segment_width] for row in scene]
        for start in range(0, len(scene[0]), segment_width)
    ]

    segment_scores = []
    segment_captions = []
    for index, segment in enumerate(segments):
        caption = assign_caption(segment, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence)
        score = compare_captions(prompt, caption)
        segment_scores.append(score)
        segment_captions.append(caption)

        if verbose:
            print(f"Segment {index + 1} caption: {caption}")
            print(f"Segment {index + 1} comparison score: {score}")

    # Average over all segments; 0 when the scene produced no segments.
    average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0

    if verbose:
        print(f"Average comparison score across all segments: {average_score}")

    return average_score, segment_captions, segment_scores
|
| 238 |
+
|
| 239 |
+
if __name__ == '__main__':
    # Smoke test: compare two similar captions with full debug tracing.
    ref = "floor with one gap. two enemies. one platform. one tower."
    gen = "giant gap with one chunk of floor. two enemies. one platform. one tower."

    score = compare_captions(ref, gen, debug=True)
    print(f"Should be: {ref}")
    print(f"  but was: {gen}")
    print(f"Score: {score}")
|
captions/evaluate_caption_order_tolerance.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
import sys, os
|
| 7 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 8 |
+
import util.common_settings as common_settings # adjust import if needed
|
| 9 |
+
from level_dataset import LevelDataset, visualize_samples, colors, mario_tiles # adjust import if needed
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from evaluate_caption_adherence import calculate_caption_score_and_samples # adjust import if needed
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import matplotlib
|
| 14 |
+
import json
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
from captions.util import extract_tileset
|
| 22 |
+
from models.pipeline_loader import get_pipeline
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_args():
    """Build and parse the command-line arguments for this evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate caption order tolerance for a diffusion model.")
    # Model and caption inputs
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained diffusion model")
    parser.add_argument("--caption", type=str, required=False, default=None, help="Caption to evaluate, phrases separated by periods")
    parser.add_argument("--tileset", type=str, help="Path to the tileset JSON file")
    parser.add_argument("--json", type=str, default="datasets\\Mar1and2_LevelsAndCaptions-regular.json", help="Path to dataset json file")
    # Sampling configuration
    parser.add_argument("--inference_steps", type=int, default=common_settings.NUM_INFERENCE_STEPS)
    parser.add_argument("--guidance_scale", type=float, default=common_settings.GUIDANCE_SCALE)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--game", type=str, choices=["Mario", "LR"], default="Mario", help="Game to evaluate (Mario or Lode Runner)")
    parser.add_argument("--describe_absence", action="store_true", default=False, help="Indicate when there are no occurrences of an item or structure")
    # Output options
    parser.add_argument("--save_as_json", action="store_true", help="Save generated levels as JSON")
    parser.add_argument("--output_dir", type=str, default="visualizations", help="Output directory if not comparing checkpoints (subdir of model directory)")
    parser.add_argument("--max_permutations", type=int, default=5, help="Maximum amount of permutations that can be made per caption")
    return parser.parse_args()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def setup_environment(seed):
    """Seed every RNG in play (python, numpy, torch) and return the device.

    Returns "cuda" when a GPU is available, otherwise "cpu".
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return "cuda" if torch.cuda.is_available() else "cpu"
|
| 53 |
+
|
| 54 |
+
def load_captions_from_json(json_path):
    """Return the caption strings from a dataset JSON file.

    The file is expected to hold a list of dicts; entries lacking a
    "caption" key are skipped.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)
    return [entry["caption"] for entry in entries if "caption" in entry]
|
| 60 |
+
|
| 61 |
+
def creation_of_parameters(caption, max_permutations):
    """Build everything needed to evaluate phrase-order permutations of a caption.

    Args:
        caption (str | list[str]): A caption, or a list of captions, whose
            period-separated phrases will be permuted.
        max_permutations (int): Cap on permutations kept per caption
            (a random sample when the full count exceeds the cap).

    Returns:
        tuple: (pipe, device, id_to_char, char_to_id, tile_descriptors,
        num_tiles, dataloader, perm_captions, caption_data).
    """
    args = parse_args()
    device = setup_environment(args.seed)

    # Select tile count and tileset path for the requested game.
    # Raw strings: the original literals ('..\TheVGLC\...') contained invalid
    # escape sequences (\T, \S, \L), which Python warns about.
    if args.game == "Mario":
        num_tiles = common_settings.MARIO_TILE_COUNT
        tileset = r'..\TheVGLC\Super Mario Bros\smb.json'
    elif args.game == "LR":
        num_tiles = common_settings.LR_TILE_COUNT
        tileset = r'..\TheVGLC\Lode Runner\Loderunner.json'
    else:
        raise ValueError(f"Unknown game: {args.game}")

    # Load pipeline
    pipe = get_pipeline(args.model_path).to(device)

    # Load tile metadata
    tile_chars, id_to_char, char_to_id, tile_descriptors = extract_tileset(tileset)

    def _capped_permutations(phrases):
        # All orderings of the phrases, randomly sampled down to the cap.
        perms = list(itertools.permutations(phrases))
        if len(perms) > max_permutations:
            perms = random.sample(perms, max_permutations)
        return perms

    # Accept either one caption string or a list of caption strings.
    perm_captions = []
    if isinstance(caption, list):
        for cap in caption:
            phrases = [p.strip() for p in cap.split('.') if p.strip()]
            perm_captions.extend('.'.join(perm) + '.' for perm in _capped_permutations(phrases))
    elif isinstance(caption, str):
        phrases = [p.strip() for p in caption.split('.') if p.strip()]
        perm_captions.extend('.'.join(perm) + '.' for perm in _capped_permutations(phrases))

    # Create a list of dicts as expected by LevelDataset
    caption_data = [{"scene": None, "caption": cap} for cap in perm_captions]

    # Initialize dataset.
    # Bug fix: num_tiles was hard-coded to common_settings.MARIO_TILE_COUNT,
    # ignoring the per-game value computed above (wrong for --game LR).
    dataset = LevelDataset(
        data_as_list=caption_data,
        shuffle=False,
        mode="text",
        augment=False,
        num_tiles=num_tiles,
        negative_captions=False,
        block_embeddings=None
    )

    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=min(16, len(perm_captions)),
        shuffle=False,
        num_workers=4,
        drop_last=False,
        persistent_workers=True
    )

    return pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles, dataloader, perm_captions, caption_data
|
| 131 |
+
|
| 132 |
+
def statistics_of_captions(captions, dataloader, compare_all_scores, pipe=None, device=None, id_to_char=None, char_to_id=None, tile_descriptors=None, num_tiles=None):
    """
    Calculate statistics of the captions.
    Returns average, standard deviation, minimum, maximum, and median of caption scores.
    """
    args = parse_args()
    if not captions:
        # Nothing to summarize; caller receives None.
        print("No captions found in the provided JSON file.")
        return
    print(f"\nLoaded {len(captions)} captions from {args.json}")

    # Summary statistics over every permutation score.
    mean_val = np.mean(compare_all_scores)
    std_val = np.std(compare_all_scores)
    low = np.min(compare_all_scores)
    high = np.max(compare_all_scores)
    mid = np.median(compare_all_scores)

    print("\n-----Scores for each caption permutation-----")
    for index, value in enumerate(compare_all_scores):
        print(f"Scores for caption {index + 1}:", value)

    print("\n-----Statistics of captions-----")
    print(f"Average score: {mean_val:.4f}")
    print(f"Standard deviation: {std_val:.4f}")
    print(f"Minimum score: {low:.4f}")
    print(f"Maximum score: {high:.4f}")
    print(f"Median score: {mid:.4f}")

    return compare_all_scores, mean_val, std_val, low, high, mid
|
| 162 |
+
|
| 163 |
+
def main():
    """Entry point: evaluate caption-order tolerance for one or many captions.

    For each caption (from --caption, or the dataset JSON when --caption is
    absent), builds permutations of its phrases, scores generated levels
    against them, aggregates per-caption statistics, optionally writes JSON
    results, and saves sample visualizations.
    """
    args = parse_args()
    # Fall back to the dataset captions when no --caption was supplied.
    if args.caption is None or args.caption == "":
        caption = load_captions_from_json(args.json)
    else:
        caption = args.caption
        #caption = ("many pipes. many coins. , many enemies. many blocks. , many platforms. many question blocks.").split(',')

    all_scores = []
    all_avg_scores = []
    all_std_dev_scores = []
    all_min_scores = []
    all_max_scores = []
    all_median_scores = []
    # Splitting on commas lets one --caption string carry several captions.
    all_captions = [item.strip() for s in caption for item in s.split(",")]

    one_caption = []
    count = 0

    output_jsonl_path = os.path.join(args.output_dir, "evaluation_caption_order_results.jsonl")
    with open(output_jsonl_path, "w") as f:
        for cap in all_captions:
            one_caption = cap

            # Initialize dataset
            pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles, dataloader, perm_caption, caption_data = creation_of_parameters(one_caption, args.max_permutations)
            if not pipe:
                print("Failed to create pipeline.")
                return

            # Generate samples for every permutation and score adherence.
            avg_score, all_samples, all_prompts, compare_all_scores = calculate_caption_score_and_samples(device, pipe, dataloader, args.inference_steps, args.guidance_scale, args.seed, id_to_char, char_to_id, tile_descriptors, args.describe_absence, output=True, height=common_settings.MARIO_HEIGHT, width=common_settings.MARIO_WIDTH)
            scores, avg_score, std_dev_score, min_score, max_score, median_score = statistics_of_captions(perm_caption, dataloader, compare_all_scores, pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles)

            if args.save_as_json:
                result_entry = {
                    "Caption": one_caption,
                    "Average score for all permutations": avg_score,
                    "Standard deviation": std_dev_score,
                    "Minimum score": min_score,
                    "Maximum score": max_score,
                    "Median score": median_score
                    #"samples": all_samples[i].tolist() if hasattr(all_samples, "__getitem__") else None,
                    #"prompt": all_prompts[i] if i < len(all_prompts) else "N/A"
                }
                f.write(json.dumps(result_entry) + "\n")

            all_avg_scores.append(avg_score)

            #scores, avg_score, std_dev_score, min_score, max_score, median_score = statistics_of_captions(perm_caption, dataloader, compare_all_scores, pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles)
            # NOTE(review): enumerate() makes each appended element an
            # (index, score) tuple rather than a bare score — looks
            # unintended; confirm whether `for score in scores` was meant.
            for score in enumerate(scores):
                all_scores.append(score)
            all_std_dev_scores.append(std_dev_score)
            all_min_scores.append(min_score)
            all_max_scores.append(max_score)
            all_median_scores.append(median_score)
            if (count % 10) == 0:
                f.flush() # Ensure each result is written immediately
                os.fsync(f.fileno()) # Ensure file is flushed to disk
            count = count + 1

        # avg_score here is from the LAST caption processed, not a global mean.
        print(f"\nAverage score across all captions: {avg_score:.4f}")

        visualizations_dir = os.path.join(os.path.dirname(__file__), "visualizations")
        # NOTE(review): `or ""` is always falsy, so this condition reduces to
        # `args.caption is not None`; if args.caption is None,
        # output_directory below is never assigned — verify intent.
        if args.caption is not None or "":
            caption_folder = args.caption.replace(" ", "_").replace(".", "_")
            output_directory = os.path.join(visualizations_dir, caption_folder)

        visualize_samples(
            all_samples,
            output_dir=output_directory,
            prompts=all_prompts[0] if all_prompts else "No prompts available"
        )
        print(f"\nVisualizations saved to: {output_directory}")


        print("\nAll samples shape:", all_samples.shape)
        print("\nAll prompts:", all_prompts)

        # Aggregate statistics across captions. NOTE(review): std/min/max/
        # median are taken over the per-caption statistics, not over the raw
        # permutation scores — confirm that is the intended aggregation.
        all_avg_score = np.mean(all_avg_scores)
        all_std_dev_score = np.std(all_std_dev_scores)
        all_min_score = np.min(all_min_scores)
        all_max_score = np.max(all_max_scores)
        all_median_score = np.median(all_median_scores)

        if args.save_as_json:
            # NOTE(review): reopening the same path with "w" truncates the
            # per-caption lines written above, and json.dump below appends a
            # pretty-printed object to a .jsonl file — confirm both are
            # intentional.
            output_jsonl_path = os.path.join(args.output_dir, "evaluation_caption_order_results.jsonl")
            with open(output_jsonl_path, "w") as f:
                if isinstance(caption, list) or (args.caption is None or args.caption == ""):
                    # Multiple captions (permuted)
                    for i, score in enumerate(all_avg_scores):
                        result_entry = {
                            "Caption": caption[i] if i < len(caption) else "N/A",
                            "Average score for all permutations": score,
                            #"samples": all_samples[i].tolist() if hasattr(all_samples, "__getitem__") else None,
                            #"prompt": all_prompts[i] if i < len(all_prompts) else "N/A"
                        }
                        f.write(json.dumps(result_entry) + "\n")
                else:
                    # Single caption
                    result_entry = {
                        "caption": caption,
                        "avg_score": avg_score,
                        "samples": all_samples.tolist(),
                        "prompts": all_prompts
                    }
                    f.write(json.dumps(result_entry) + "\n")

                results = {

                    "Scores of all captions": {
                        "Scores": all_scores,
                        "Number of captions": len(all_scores),
                        "Average of all permutations": all_avg_score,
                        "Standard deviation of all permutations": all_std_dev_score,
                        "Min score of all permutations": all_min_score,
                        "Max score of all permutations": all_max_score,
                        "Median score of all permutations": all_median_score
                    },
                }
                json.dump(results, f, indent=4)

        print(f"Results saved to {output_jsonl_path}")
if __name__ == "__main__":
    main()
|
captions/util.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import sys
import os
from collections import Counter

# This file contains utility functions for analyzing and describing levels in both Lode Runner and Super Mario Bros.

# Could define these via the command line, but for now they are hardcoded
coarse_locations = True  # describe positions with coarse words rather than exact coordinates
coarse_counts = True  # describe counts with words ("a few") rather than exact numbers
pluralize = True  # append plural "s" to nouns when the count is not 1
give_staircase_lengths = False  # include staircase lengths in descriptions
|
| 14 |
+
def describe_size(count):
    """Classify a count as "small" (four or fewer) or "big" (five or more)."""
    return "small" if count <= 4 else "big"
|
| 17 |
+
|
| 18 |
+
def describe_quantity(count):
    """Map a non-negative count to a coarse quantity word.

    0 -> "no", 1 -> "one", 2 -> "two", 3-4 -> "a few",
    5-9 -> "several", 10+ -> "many".
    """
    if count == 0:
        return "no"
    if count == 1:
        return "one"
    if count == 2:
        return "two"
    if count < 5:
        return "a few"
    if count < 10:
        return "several"
    return "many"
|
| 25 |
+
|
| 26 |
+
def get_tile_descriptors(tileset):
    """Creates a mapping from tile character to its set of descriptors."""
    descriptors = {}
    for char, attrs in tileset["tiles"].items():
        descriptors[char] = set(attrs)
    # Fake tiles. Should these contain anything? Note that code elsewhere expects everything to be passable or solid
    descriptors["!"] = {"passable"}
    descriptors["*"] = {"passable"}
    return descriptors
|
| 33 |
+
|
| 34 |
+
def analyze_floor(scene, id_to_char, tile_descriptors, describe_absence):
    """Analyze the bottom row of the scene and generate a floor description.

    Args:
        scene: 2D grid of tile ids (ints indexing into id_to_char).
        id_to_char: mapping from tile id to tile character.
        tile_descriptors: mapping from tile character to descriptor set.
        describe_absence: if True, explicitly mention a missing floor.

    Returns:
        A phrase such as "full floor", "no floor", "floor with two gaps",
        or "giant gap with one chunk of floor"; "" when there is no floor
        and describe_absence is False.

    Raises:
        ValueError: if a bottom-row tile is neither passable, solid, nor
            enemy (previously this was preceded by raw debug prints; the
            diagnostics are now part of the exception message).
    """
    def _descriptors(tile):
        # Central lookup so every branch keys through id_to_char consistently
        # (one old debug print mistakenly looked up by raw tile id).
        return tile_descriptors.get(id_to_char[tile], [])

    def _count_runs(row, run_pred, break_pred, error_msg):
        """Count contiguous runs of tiles matching run_pred, separated by break_pred tiles."""
        runs = 0
        in_run = False
        for tile in row:
            if run_pred(tile):
                if not in_run:
                    runs += 1
                in_run = True
            elif break_pred(tile):
                in_run = False
            else:
                raise ValueError(
                    f"{error_msg} (tile id {tile!r}, char {id_to_char[tile]!r}, "
                    f"descriptors {sorted(_descriptors(tile))})"
                )
        return runs

    WIDTH = len(scene[0])
    last_row = scene[-1]  # The FLOOR row of the scene

    # Diggable tiles count as solid for floor purposes.
    solid_count = sum(
        1 for tile in last_row
        if tile in id_to_char and (
            "solid" in _descriptors(tile) or "diggable" in _descriptors(tile)
        )
    )
    passable_count = sum(1 for tile in last_row if "passable" in _descriptors(tile))

    if solid_count == WIDTH:
        return "full floor"
    if passable_count == WIDTH:
        return "no floor" if describe_absence else ""
    if solid_count > passable_count:
        # Mostly solid: describe the passable runs as gaps.
        # Enemies also count as gaps since they immediately fall into them.
        gaps = _count_runs(
            last_row,
            lambda t: "passable" in _descriptors(t) or "enemy" in _descriptors(t),
            lambda t: "solid" in _descriptors(t),
            "Every tile should be passable, solid, or enemy",
        )
        return (f"floor with {describe_quantity(gaps) if coarse_counts else gaps} gap"
                + ("s" if pluralize and gaps != 1 else ""))
    # Mostly passable: describe the solid runs as chunks of floor.
    chunks = _count_runs(
        last_row,
        lambda t: "solid" in _descriptors(t),
        lambda t: "passable" in _descriptors(t) or "enemy" in _descriptors(t),
        "Every tile should be either passable or solid",
    )
    return (f"giant gap with {describe_quantity(chunks) if coarse_counts else chunks} chunk"
            + ("s" if pluralize and chunks != 1 else "") + " of floor")
|
| 94 |
+
|
| 95 |
+
def count_in_scene(scene, tiles, exclude=None):
    """Count standalone tiles in the scene, skipping excluded positions.

    Args:
        scene: 2D grid of tiles.
        tiles: collection of tile values to count.
        exclude: optional set of (row, col) positions to skip. The previous
            default of a shared mutable ``set()`` is replaced with None.

    Returns:
        Number of matching tiles outside the excluded positions.
    """
    if exclude is None:
        exclude = set()
    return sum(
        1
        for r, row in enumerate(scene)
        for c, t in enumerate(row)
        if (r, c) not in exclude and t in tiles
    )
|
| 106 |
+
|
| 107 |
+
def count_caption_phrase(scene, tiles, name, names, offset=0, describe_absence=False, exclude=None):
    """Return a caption fragment describing how many of a tile type appear.

    Args:
        scene: 2D grid of tiles.
        tiles: tile values to count.
        name: singular noun for the tile type.
        names: plural noun for the tile type.
        offset: amount added to the raw count before captioning.
        describe_absence: if True, mention the tile type even when absent.
        exclude: optional set of (row, col) positions to ignore. The previous
            default of a shared mutable ``set()`` is replaced with None.

    Returns:
        A fragment with a leading space and trailing period (e.g.
        " two enemies."), or "" when the count is zero and
        describe_absence is False.
    """
    if exclude is None:
        exclude = set()
    count = offset + count_in_scene(scene, tiles, exclude)
    if count > 0:
        label = names if pluralize and count > 1 else name
        return f" {describe_quantity(count) if coarse_counts else count} {label}."
    if describe_absence:
        return f" no {names}."
    return ""
|
| 117 |
+
|
| 118 |
+
def in_column(scene, x, tile):
    """Return True if `tile` appears anywhere in column `x` of the scene."""
    return any(row[x] == tile for row in scene)
|
| 124 |
+
|
| 125 |
+
def analyze_ceiling(scene, id_to_char, tile_descriptors, describe_absence, ceiling_row=1):
    """Inspect one row (0-based index, default 1) for a ceiling.

    Returns " full ceiling.", " ceiling with N gap(s).", " no ceiling."
    (only when describe_absence is True), or "" when there are not enough
    solid tiles to call it a ceiling.
    """
    width = len(scene[0])
    row = scene[ceiling_row]

    def has(descriptor, tile):
        return descriptor in tile_descriptors.get(id_to_char[tile], [])

    solid_total = sum(1 for tile in row if has("solid", tile))

    if solid_total == width:
        return " full ceiling."
    if solid_total > width // 2:
        # Count contiguous runs of non-solid tiles as gaps. Enemies are
        # tagged "moving" rather than "passable", but they immediately fall
        # through, so they count toward gaps as well.
        gaps = 0
        inside_gap = False
        for tile in row:
            if has("passable", tile) or has("moving", tile):
                if not inside_gap:
                    gaps += 1
                inside_gap = True
            else:
                inside_gap = False
        count_text = describe_quantity(gaps) if coarse_counts else gaps
        suffix = "s" if pluralize and gaps != 1 else ""
        return f" ceiling with {count_text} gap{suffix}."
    if describe_absence:
        return " no ceiling."
    return ""
|
| 162 |
+
|
| 163 |
+
def extract_tileset(tileset_path):
    """Load a tileset JSON file and derive the tile lookup tables.

    Args:
        tileset_path: path to a JSON file with a top-level "tiles" mapping.

    Returns:
        (tile_chars, id_to_char, char_to_id, tile_descriptors): the sorted
        list of tile characters, the two id<->char dicts, and the mapping
        from tile character to descriptor set.
    """
    with open(tileset_path, "r") as f:
        tileset = json.load(f)

    tile_chars = sorted(tileset["tiles"].keys())
    id_to_char = dict(enumerate(tile_chars))
    char_to_id = {char: idx for idx, char in enumerate(tile_chars)}
    tile_descriptors = get_tile_descriptors(tileset)

    return tile_chars, id_to_char, char_to_id, tile_descriptors
|
| 182 |
+
|
| 183 |
+
def flood_fill(scene, visited, start_row, start_col, id_to_char, tile_descriptors, excluded, pipes=False, target_descriptor=None):
    """Collect the connected structure of tiles reachable from a start cell.

    Iterative depth-first fill over 4-connected neighbors.

    Args:
        scene: 2D grid of tile ids.
        visited: set of (row, col) already claimed; updated in place.
        start_row, start_col: seed position.
        id_to_char: tile id -> tile character.
        tile_descriptors: tile character -> descriptor collection.
        excluded: positions that must never be included.
        pipes: when target_descriptor is None, True restricts the fill to
            pipe tiles, False to non-pipe solid tiles.
        target_descriptor: when given, fill over tiles carrying this
            descriptor instead of the solid/pipe logic.

    Returns:
        List of (row, col) positions belonging to the structure.
    """
    rows, cols = len(scene), len(scene[0])
    pending = [(start_row, start_col)]
    structure = []

    while pending:
        row, col = pending.pop()
        if (row, col) in visited or (row, col) in excluded:
            continue

        char = id_to_char[scene[row][col]]
        descriptors = tile_descriptors.get(char, [])

        if target_descriptor is not None:
            if target_descriptor not in descriptors:
                continue
        elif ("solid" not in descriptors
              or (not pipes and "pipe" in descriptors)
              or (pipes and "pipe" not in descriptors)):
            continue

        visited.add((row, col))
        structure.append((row, col))

        for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            # Adjacent pipes are distinct structures: never step sideways
            # off a pipe edge ('>' and ']' are right edges, '<' and '['
            # left edges).
            if d_col == 1 and char in (">", "]"):
                continue
            if d_col == -1 and char in ("<", "["):
                continue
            n_row, n_col = row + d_row, col + d_col
            if 0 <= n_row < rows and 0 <= n_col < cols:
                pending.append((n_row, n_col))

    return structure
|
mapsheet.png
ADDED
|
models/general_training_helper.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader
|
| 2 |
+
from level_dataset import LevelDataset
|
| 3 |
+
import random
|
| 4 |
+
from util.plotter import Plotter
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
import os
|
| 7 |
+
import threading
|
| 8 |
+
import json
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def create_dataloaders(json_path, val_json, tokenizer, data_mode, augment, num_tiles,
                       negative_prompt_training, block_embeddings, batch_size):
    """Build the training DataLoader and, optionally, a validation one.

    Validation data (when val_json is given) is never shuffled or augmented
    and does not drop the last partial batch.

    Returns:
        (train_dataloader, val_dataloader); val_dataloader is None when
        val_json is None.
    """
    def _make_dataset(path, shuffle, do_augment):
        # Shared dataset construction; only shuffling/augmentation differ
        # between train and validation.
        return LevelDataset(
            json_path=path,
            tokenizer=tokenizer,
            shuffle=shuffle,
            mode=data_mode,
            augment=do_augment,
            num_tiles=num_tiles,
            negative_captions=negative_prompt_training,
            block_embeddings=block_embeddings,
        )

    train_dataloader = DataLoader(
        _make_dataset(json_path, True, augment),
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=True,
        persistent_workers=True,
    )

    val_dataloader = None
    if val_json is not None:
        val_dataloader = DataLoader(
            _make_dataset(val_json, False, False),
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
        )

    return train_dataloader, val_dataloader
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_random_training_samples(train_dataloader, negative_prompt_training, output_dir=None):
    """Pick four random caption samples (and negatives, if trained with them).

    Prints the sampled captions and, when output_dir is given, also writes
    them to <output_dir>/sample_captions.txt.

    Returns:
        (sample_captions, sample_negative_captions); the latter is the empty
        string when negative_prompt_training is False.
    """
    dataset = train_dataloader.dataset
    indices = [random.randint(0, len(dataset) - 1) for _ in range(4)]

    sample_captions = [dataset[i][1] for i in indices]
    print("Sample captions:")
    for caption in sample_captions:
        print(caption)

    sample_negative_captions = ""
    if negative_prompt_training:
        sample_negative_captions = [dataset[i][2] for i in indices]
        print("Sample negative captions:")
        for caption in sample_negative_captions:
            print(f" NEG: {caption}")

    # Persist the sampled captions for later inspection.
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, "sample_captions.txt")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write("Sample captions:\n")
            for caption in sample_captions:
                f.write(str(caption) + "\n")
            if negative_prompt_training:
                f.write("\nSample negative captions:\n")
                for caption in sample_negative_captions:
                    f.write(str(caption) + "\n")
        print(f"Sample captions written to {out_path}")

    return sample_captions, sample_negative_captions
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def start_plotter(log_file, output_dir, left_key, right_key, left_label, right_label, png_name):
    """Launch a daemon thread that periodically refreshes a training plot.

    Returns:
        (plotter, plot_thread) so the caller can stop plotting later via
        kill_plotter.
    """
    timestamp = datetime.now().strftime(r'%Y%m%d-%H%M%S')
    output_png = f'{png_name}_{timestamp}.png'

    plotter = Plotter(log_file, update_interval=5.0, left_key=left_key, right_key=right_key,
                      left_label=left_label, right_label=right_label, output_png=output_png)
    # Daemon thread: it must not keep the process alive after training ends.
    plot_thread = threading.Thread(target=plotter.start_plotting, daemon=True)
    plot_thread.start()
    print(f"{png_name} plotting enabled. Progress will be saved to {os.path.join(output_dir, output_png)}")
    return plotter, plot_thread
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def kill_plotter(plotter, plot_thread):
    """Stop the plotting thread and wait briefly for it to shut down."""
    if not (plot_thread and plot_thread.is_alive()):
        return
    plotter.stop_plotting()
    plot_thread.join(timeout=5.0)
    if plot_thread.is_alive():
        print("Warning: Plot thread did not terminate properly")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def load_config_from_json(config_path):
    """Load hyperparameters from a JSON config file.

    Prints every loaded key/value pair for verification.

    Raises:
        FileNotFoundError: if the file does not exist (logged, then re-raised).
        json.JSONDecodeError: if the file is not valid JSON (logged, then re-raised).
    """
    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"Error loading config file: {e}")
        raise e

    print(f"Configuration loaded from {config_path}")
    print("Loaded hyperparameters:")
    for key, value in config.items():
        print(f" {key}: {value}")
    return config
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def update_args_from_config(args, config):
    """Overwrite attributes of an argparse namespace with config values.

    Only keys that already exist as attributes on args are applied; unknown
    config keys are silently ignored.

    Returns:
        The same args object, mutated in place.
    """
    for key in config:
        if hasattr(args, key):
            setattr(args, key, config[key])
    return args
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def get_scene_from_embeddings(image, block_embeddings):
    """Convert a batch of embedding maps into per-tile probability maps.

    For every spatial position, the cosine similarity between its embedding
    and each known block embedding is computed, and a softmax over blocks
    turns those similarities into a probability distribution. (The old
    comment claimed argmax "indices of most similar tiles"; the code has
    always produced a softmax distribution.)

    Args:
        image: tensor of shape [batch, embedding_dim, height, width].
        block_embeddings: tensor of shape [num_blocks, embedding_dim].

    Returns:
        CPU tensor of shape [batch, num_blocks, height, width] holding, for
        each position, a softmax distribution over block types.
    """
    batch_size, embedding_dim, height, width = image.shape
    num_blocks = block_embeddings.shape[0]

    # Flatten spatial dims: [batch * height * width, embedding_dim]
    flat_samples = image.permute(0, 2, 3, 1).reshape(-1, embedding_dim)

    # Normalize both sides so the dot products below are cosine similarities.
    flat_samples = F.normalize(flat_samples, p=2, dim=1).cpu()
    block_embeddings = F.normalize(block_embeddings, p=2, dim=1)

    # [batch * height * width, num_blocks]
    similarities = torch.matmul(flat_samples, block_embeddings.t())

    probs = torch.softmax(similarities, dim=1)

    # Reshape back to [batch, num_blocks, height, width].
    # Previously hard-coded to 13 blocks, which broke any other tileset size.
    probs = probs.reshape(batch_size, height, width, num_blocks)
    probs = probs.permute(0, 3, 1, 2)

    return probs.detach().cpu()
|
| 171 |
+
|
| 172 |
+
|
models/latent_diffusion_pipeline.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers import DDPMPipeline
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Optional, Union, List, Tuple
|
| 5 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 6 |
+
from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
|
| 7 |
+
import util.common_settings as common_settings
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
from models.general_training_helper import get_scene_from_embeddings
|
| 11 |
+
|
| 12 |
+
class UnconditionalDDPMPipeline(DDPMPipeline):
    """DDPM pipeline for unconditional level generation.

    Optionally carries block embeddings (for decoding latent embedding maps
    back to tile distributions) and per-sprite temperature scaling factors
    applied at the end of inference.
    """

    def __init__(self, unet, scheduler, block_embeddings=None):
        super().__init__(unet, scheduler)
        self.block_embeddings = block_embeddings

    def save_pretrained(self, save_directory):
        """Save the pipeline, plus the block-embedding tensor when present."""
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory)
        if self.block_embeddings is not None:
            torch.save(self.block_embeddings,
                       os.path.join(save_directory, "block_embeddings.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Load the pipeline and restore block embeddings if they were saved."""
        pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
        embeddings_file = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(embeddings_file):
            pipeline.block_embeddings = torch.load(embeddings_file, map_location="cpu")
        else:
            pipeline.block_embeddings = None
        return pipeline

    def give_sprite_scaling_factors(self, sprite_scaling_factors):
        """Set per-sprite temperature scaling factors applied during inference."""
        self.sprite_scaling_factors = sprite_scaling_factors

    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        height: int = common_settings.MARIO_HEIGHT,
        width: int = common_settings.MARIO_WIDTH,
        latents: Optional[torch.FloatTensor] = None,
        show_progress_bar=True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Run the reverse-diffusion loop and return generated tile maps.

        Returns per-tile probability maps (softmax over channels) or, when
        block embeddings are present, maps decoded from embedding space.
        """
        self.unet.eval()
        with torch.no_grad():
            # Start either from caller-supplied latents or fresh noise.
            if latents is not None:
                image = latents.to(self.device)
            else:
                noise_shape = (batch_size, self.unet.config.in_channels, height, width)
                image = torch.randn(noise_shape, generator=generator, device=self.device)

            self.scheduler.set_timesteps(num_inference_steps)

            timesteps = (self.progress_bar(self.scheduler.timesteps)
                         if show_progress_bar else self.scheduler.timesteps)
            for t in timesteps:
                model_output = self.unet(image, t).sample
                image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            # Per-sprite temperature scaling, if configured via
            # give_sprite_scaling_factors().
            if getattr(self, "sprite_scaling_factors", None) is not None:
                image = image / self.sprite_scaling_factors.view(1, -1, 1, 1)

            if self.block_embeddings is not None:
                image = get_scene_from_embeddings(image, self.block_embeddings)
            else:
                image = F.softmax(image, dim=1)
                image = image.detach().cpu()

            if not return_dict:
                return (image,)

            return ImagePipelineOutput(images=image)

    def print_unet_architecture(self):
        """Print the architecture of the UNet model."""
        print(self.unet)
|
models/pipeline_loader.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.text_diffusion_pipeline import TextConditionalDDPMPipeline
|
| 2 |
+
from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline
|
| 3 |
+
import os
|
| 4 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_pipeline(model_path):
    """Load the appropriate diffusion pipeline for a model path or Hub id.

    Local directories are inspected for "unet"/"text_encoder" subfolders to
    choose between text-conditional and unconditional pipelines; anything
    else is treated as a Hugging Face Hub id and its pipeline config is
    inspected instead.

    Raises:
        ValueError: if a local directory has no "unet" subfolder (previously
            this fell through and crashed with NameError on `pipe`).
    """
    if os.path.isdir(model_path):
        # Local checkout: subfolder layout tells us the pipeline type.
        if os.path.exists(os.path.join(model_path, "unet")):
            if os.path.exists(os.path.join(model_path, "text_encoder")):
                # A text encoder plus a unet means text-conditional diffusion.
                pipe = TextConditionalDDPMPipeline.from_pretrained(model_path)
            else:
                # No text encoder: unconditional diffusion model.
                pipe = UnconditionalDDPMPipeline.from_pretrained(model_path)
        else:
            raise ValueError(
                f"Model directory {model_path!r} has no 'unet' subfolder; "
                "cannot determine pipeline type."
            )
    else:
        # Assume it's a Hugging Face Hub model ID.
        # Try to load the config to determine if it's text-conditional.
        try:
            config, _ = DiffusionPipeline.load_config(model_path)
            components = config.get("components", {})
        except Exception:
            components = {}
        if "text_encoder" in components or "text_encoder" in str(components):
            custom = "models.text_diffusion_pipeline.TextConditionalDDPMPipeline"
        else:
            custom = "models.latent_diffusion_pipeline.UnconditionalDDPMPipeline"
        pipe = DiffusionPipeline.from_pretrained(
            model_path,
            custom_pipeline=custom,
            trust_remote_code=True,
        )

    return pipe
|
models/sentence_transformers_helper.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModel
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
#Mean Pooling - Take average of all tokens
|
| 6 |
+
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, weighted by the attention mask.

    Args:
        model_output: transformer output exposing .last_hidden_state of
            shape [batch, tokens, dim].
        attention_mask: [batch, tokens] mask (1 for real tokens, 0 for pad).

    Returns:
        [batch, dim] mean-pooled sentence embeddings.
    """
    tokens = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
    summed = torch.sum(tokens * mask, 1)
    counts = torch.clamp(mask.sum(1), min=1e-9)  # guard against all-pad rows
    return summed / counts
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
#Encode text
|
| 13 |
+
def encode(texts, tokenizer, model, device='cpu'):
    """Embed a list of texts into normalized sentence vectors.

    Tokenizes, runs the transformer without gradients, mean-pools the token
    embeddings, and L2-normalizes the result.

    Returns:
        [len(texts), dim] tensor on `device`.
    """
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    encoded_input.to(device)

    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)

    pooled = mean_pooling(model_output, encoded_input['attention_mask'])
    normalized = F.normalize(pooled, p=2, dim=1)
    return normalized.to(device)
|
| 31 |
+
|
| 32 |
+
# Get embeddings for a batch of captions and optional negative captions
|
| 33 |
+
def get_embeddings(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu'):
    """Encode unconditional, positive, and negative caption batches.

    The result always contains `batch_size` empty-string embeddings (the
    unconditional branch); positive captions are appended after them, and
    negative captions, when present, are prepended before everything.

    Returns:
        Tensor of shape [total, 1, dim].
    """
    embeddings = encode([""] * batch_size, tokenizer, model, device)

    if captions is not None:
        positive = encode(captions, tokenizer, model, device)
        embeddings = torch.cat((embeddings, positive), dim=0)

    if neg_captions is not None:
        negative = encode(neg_captions, tokenizer, model, device)
        embeddings = torch.cat((negative, embeddings), dim=0)

    # Add a sequence-length dimension of 1.
    return embeddings.unsqueeze(1)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_embeddings_split(batch_size, tokenizer, model, captions=None, neg_captions=None, device='cpu', max_length=20):
    """Like get_embeddings, but encodes each caption sentence-by-sentence.

    Captions are split on '.', padded with empty strings to a common
    length, and each sentence gets its own embedding vector.

    Raises:
        ValueError: if any caption has more sentences than max_length.

    Returns:
        Tensor of shape [total, padding_length, dim] on `device`.
    """
    pos_len = max([s.count(".") for s in captions]) if captions else 1
    neg_len = max([s.count(".") for s in neg_captions]) if neg_captions else 1
    padding_length = max(pos_len, neg_len)
    if padding_length > max_length:
        raise ValueError(f"Token sequence length {padding_length} exceeds specified length {max_length}.")

    # Unconditional branch: empty strings, one per batch element.
    empty_split = split_sentences([""] * batch_size, padding_length)
    embeddings = get_embeddings_from_split(empty_split, tokenizer, model, device)

    if captions is not None:
        positive = get_embeddings_from_split(
            split_sentences(captions, padding_length), tokenizer, model, device)
        embeddings = torch.cat((embeddings, positive), dim=0)

    if neg_captions is not None:
        negative = get_embeddings_from_split(
            split_sentences(neg_captions, padding_length), tokenizer, model, device)
        embeddings = torch.cat((negative, embeddings), dim=0)

    # Already [total, padding_length, dim]; no extra unsqueeze needed.
    return embeddings.to(device)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
#This method takes a caption batch in list form, and outputs a 2d list where every caption has been split by period
|
| 80 |
+
def split_sentences(caption_array, padding_length=20):
    """Split each caption on '.' and right-pad with empty strings.

    Args:
        caption_array: list of caption strings.
        padding_length: target number of sentence slots per caption.

    Returns:
        A list (one entry per caption) of lists of padding_length sentence
        strings; unused trailing slots are "".
    """
    padded = []
    for caption in caption_array:
        sentences = [part.strip() for part in caption.split(".") if part.strip()]
        sentences.extend([""] * (padding_length - len(sentences)))
        padded.append(sentences)
    return padded
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#Expects all split vectors to be the same length
|
| 94 |
+
def get_embeddings_from_split(caption_batch, tokenizer, model, device='cpu'):
    """Encode pre-split captions; all inner lists must share one length.

    Each caption's sentence list is encoded as its own mini-batch, giving a
    [num_sentences, dim] tensor per caption, and the results are stacked.

    Returns:
        Tensor of shape [len(caption_batch), num_sentences, dim].
    """
    encoded = [encode(sentences, tokenizer, model, device)
               for sentences in caption_batch]
    return torch.stack(encoded, dim=0)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
    # Manual smoke test: split two captions and embed them sentence-by-sentence.
    captions = split_sentences(["Hello. My name is George. How. Are you doing. Today?",
                                "I am doing. Just fine. Thanks."])
    model_url = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
    device = 'cuda'
    tokenizer = AutoTokenizer.from_pretrained(model_url)
    model = AutoModel.from_pretrained(model_url, trust_remote_code=True).to(device)
    get_embeddings_from_split(captions, tokenizer, model, device)
|
models/text_diffusion_pipeline.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import NamedTuple, Optional
|
| 4 |
+
import os
|
| 5 |
+
from diffusers import DDPMPipeline, UNet2DConditionModel, DDPMScheduler
|
| 6 |
+
import json
|
| 7 |
+
# Running the main at the end of this requires messing with this import
|
| 8 |
+
from models.text_model import TransformerModel
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from transformers import AutoTokenizer, AutoModel
|
| 12 |
+
import util.common_settings as common_settings
|
| 13 |
+
import models.sentence_transformers_helper as st_helper
|
| 14 |
+
import models.text_model as text_model
|
| 15 |
+
from models.general_training_helper import get_scene_from_embeddings
|
| 16 |
+
|
| 17 |
+
class PipelineOutput(NamedTuple):
    """Return type of TextConditionalDDPMPipeline.__call__."""

    # Generated samples, shape (batch_size, channels, height, width).
    # Values are tile-class probabilities (after softmax) or decoded scene
    # logits depending on whether block embeddings are in use.
    images: torch.Tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Create a custom pipeline for text-conditional generation
|
| 23 |
+
# Create a custom pipeline for text-conditional generation
class TextConditionalDDPMPipeline(DDPMPipeline):
    """DDPM pipeline that generates tile-based scenes conditioned on text captions.

    Combines a ``UNet2DConditionModel`` and ``DDPMScheduler`` with an optional
    text encoder — either the project-local ``TransformerModel`` or a pretrained
    HuggingFace model — supporting classifier-free guidance and optional
    negative prompts.
    """

    def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
        super().__init__(unet=unet, scheduler=scheduler)
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        # The UNet advertises negative-prompt support via a custom attribute.
        self.supports_negative_prompt = hasattr(unet, 'negative_prompt_support') and unet.negative_prompt_support
        self.supports_pretrained_split = supports_pretrained_split
        self.block_embeddings = block_embeddings

        if self.tokenizer is None and self.text_encoder is not None:
            # Use the tokenizer from the text encoder if not provided
            self.tokenizer = self.text_encoder.tokenizer

        # Register the text_encoder so that .to(), .cpu(), .cuda(), etc. work correctly
        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
        )

    # Override the to() method to ensure text_encoder is moved to the correct device
    def to(self, device=None, dtype=None):
        """Move the pipeline — including any custom text encoder — to device/dtype."""
        pipeline = super().to(device, dtype)

        # A custom (non-diffusers) text encoder may not be tracked by the parent
        # implementation, so move it explicitly as well.
        if self.text_encoder is not None:
            self.text_encoder.to(device)

        return pipeline

    def save_pretrained(self, save_directory):
        """Save the pipeline: UNet, scheduler, flags, block embeddings, text encoder.

        Pretrained HF encoders are saved by name only (``loading_info.json``) so
        they can be re-fetched from the hub instead of duplicating giant weights.
        """
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory)  # saves UNet and scheduler

        # Save block_embeddings tensor if it exists
        if self.block_embeddings is not None:
            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))

        # Save supports_negative_prompt and supports_pretrained_split flags
        with open(os.path.join(save_directory, "pipeline_config.json"), "w") as f:
            json.dump({
                "supports_negative_prompt": self.supports_negative_prompt,
                "supports_pretrained_split": self.supports_pretrained_split,
                "text_encoder_type": type(self.text_encoder).__name__
            }, f)

        # Text encoder/tokenizer saving is different depending on whether we're
        # using a larger pretrained model.
        # Bug fix: guard None first — previously a missing encoder fell into the
        # "pretrained" branch below and crashed on self.text_encoder.config.
        if self.text_encoder is None:
            return

        if isinstance(self.text_encoder, TransformerModel):
            # Save custom text encoder (config + weights + tokenizer)
            self.text_encoder.save_pretrained(os.path.join(save_directory, "text_encoder"))
        else:
            # Save pretrained tokenizer/encoder by name, so we can load from
            # huggingface instead of saving a giant local model
            text_encoder_info = {
                "text_encoder_name": self.text_encoder.config.name_or_path,
                "tokenizer_name": self.tokenizer.name_or_path,
            }

            text_encoder_directory = os.path.join(save_directory, "text_encoder")
            os.makedirs(text_encoder_directory, exist_ok=True)

            with open(os.path.join(text_encoder_directory, "loading_info.json"), "w") as f:
                json.dump(text_encoder_info, f)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Load a pipeline previously saved with :meth:`save_pretrained`.

        Restores the UNet, scheduler, text encoder/tokenizer (new name-based
        format or legacy full-weights format), block embeddings, and flags.
        """
        # Load components manually
        unet_path = os.path.join(pretrained_model_path, "unet")
        unet = UNet2DConditionModel.from_pretrained(unet_path)

        scheduler_path = os.path.join(pretrained_model_path, "scheduler")
        # Have heard that DDIMScheduler might be faster for inference, though not necessarily better
        scheduler = DDPMScheduler.from_pretrained(scheduler_path)

        tokenizer = None
        text_encoder_path = os.path.join(pretrained_model_path, "text_encoder")

        if os.path.exists(text_encoder_path):
            # Test for the new saving system, where we save a simple config file
            if os.path.exists(os.path.join(text_encoder_path, "loading_info.json")):
                with open(os.path.join(text_encoder_path, "loading_info.json"), "r") as f:
                    encoder_config = json.load(f)

                text_encoder = AutoModel.from_pretrained(encoder_config['text_encoder_name'], trust_remote_code=True)
                tokenizer = AutoTokenizer.from_pretrained(encoder_config['tokenizer_name'])
            else:
                # Legacy loading system: loads models directly if the whole
                # thing is saved in the directory.
                try:
                    text_encoder = AutoModel.from_pretrained(text_encoder_path, local_files_only=True, trust_remote_code=True)
                    tokenizer = AutoTokenizer.from_pretrained(text_encoder_path, local_files_only=True)
                except (ValueError, KeyError):
                    # Not a HF checkpoint — fall back to the custom TransformerModel format.
                    text_encoder = TransformerModel.from_pretrained(text_encoder_path)
                    tokenizer = text_encoder.tokenizer
        else:
            text_encoder = None

        # Instantiate the pipeline
        pipeline = cls(
            unet=unet,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            **kwargs,
        )

        # Load block embeddings if present
        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(block_embeds_path):
            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
        else:
            pipeline.block_embeddings = None

        # Load supports_negative_prompt / supports_pretrained_split flags if present
        config_path = os.path.join(pretrained_model_path, "pipeline_config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            pipeline.supports_negative_prompt = config.get("supports_negative_prompt", False)
            pipeline.supports_pretrained_split = config.get("supports_pretrained_split", False)
        return pipeline

    # --- Handle batching for captions ---
    def _prepare_text_batch(self, text: Optional[str | list[str]], batch_size: int, name: str) -> Optional[list[str]]:
        """Normalize a prompt argument to a list of exactly ``batch_size`` strings.

        A single string (or one-element list) is broadcast across the batch.
        Raises ValueError on a length mismatch or unsupported type.
        """
        if text is None:
            return None
        if isinstance(text, str):
            return [text] * batch_size
        if isinstance(text, list):
            if len(text) == 1:
                return text * batch_size
            if len(text) != batch_size:
                raise ValueError(f"{name} list length {len(text)} does not match batch_size {batch_size}")
            return text
        raise ValueError(f"{name} must be a string or list of strings")

    def _prepare_initial_sample(self,
                                raw_latent_sample: Optional[torch.Tensor],
                                input_scene: Optional[torch.Tensor],
                                batch_size: int, height: int, width: int,
                                generator: Optional[torch.Generator]) -> torch.Tensor:
        """Prepare the initial sample for diffusion.

        Priority: an explicit latent tensor, then a one-hot-encoded tile scene,
        then pure Gaussian noise. Returns (batch, in_channels, height, width).
        """
        sample_shape = (batch_size, self.unet.config.in_channels, height, width)

        if raw_latent_sample is not None:
            if input_scene is not None:
                raise ValueError("Cannot provide both raw_latent_sample and input_scene")
            sample = raw_latent_sample.to(self.device)
            if sample.shape[1] != sample_shape[1]:
                raise ValueError(f"Wrong number of channels in raw_latent_sample: Expected {self.unet.config.in_channels} but got {sample.shape[1]}")
            if sample.shape[0] == 1 and batch_size > 1:
                sample = sample.repeat(batch_size, 1, 1, 1)
            elif sample.shape[0] != batch_size:
                raise ValueError(f"raw_latent_sample batch size {sample.shape[0]} does not match batch_size {batch_size}")
        elif input_scene is not None:
            # input_scene can be (H, W) or (batch_size, H, W).
            # as_tensor accepts tensors, arrays, or nested lists without the
            # copy warning torch.tensor() raises for tensor inputs.
            scene_tensor = torch.as_tensor(input_scene, dtype=torch.long, device=self.device)
            if scene_tensor.dim() == 2:
                # (H, W) -> repeat for batch
                scene_tensor = scene_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
            elif scene_tensor.shape[0] == 1 and batch_size > 1:
                scene_tensor = scene_tensor.repeat(batch_size, 1, 1)
            elif scene_tensor.shape[0] != batch_size:
                raise ValueError(f"input_scene batch size {scene_tensor.shape[0]} does not match batch_size {batch_size}")
            # One-hot encode: (batch, H, W, C) -> (batch, C, H, W)
            one_hot = F.one_hot(scene_tensor, num_classes=self.unet.config.in_channels).float()
            sample = one_hot.permute(0, 3, 1, 2)
        else:
            # Start from random noise
            sample = torch.randn(sample_shape, generator=generator, device=self.device)

        return sample

    def __call__(
        self,
        caption: Optional[str | list[str]] = None,
        negative_prompt: Optional[str | list[str]] = None,
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        guidance_scale: float = common_settings.GUIDANCE_SCALE,
        height: int = common_settings.MARIO_HEIGHT,
        width: int = common_settings.MARIO_WIDTH,
        raw_latent_sample: Optional[torch.FloatTensor] = None,
        input_scene: Optional[torch.Tensor] = None,
        output_type: str = "tensor",
        batch_size: int = 1,
        show_progress_bar: bool = True,
    ) -> PipelineOutput:
        """Generate a batch of images based on text input using the diffusion model.

        Args:
            caption: Text description(s) of the desired output. Can be a string or list of strings.
            negative_prompt: Text description(s) of what should not appear in the output. String or list.
            generator: Random number generator for reproducibility.
            num_inference_steps: Number of denoising steps (more = higher quality, slower).
            guidance_scale: How strongly the generation follows the text prompt (higher = stronger).
            height: Height of generated image in tiles.
            width: Width of generated image in tiles.
            raw_latent_sample: Optional starting point for diffusion instead of random noise.
                Must have correct number of channels matching the UNet.
            input_scene: Optional 2D or 3D int tensor where each value corresponds to a tile type.
                Will be converted to one-hot encoding as starting point.
            output_type: Currently only "tensor" is supported.
            batch_size: Number of samples to generate in parallel.
            show_progress_bar: Whether to display a tqdm progress bar over timesteps.

        Returns:
            PipelineOutput containing the generated image tensor (batch_size, ...).
        """
        # Validate text encoder if we need it
        if caption is not None and self.text_encoder is None:
            raise ValueError("Text encoder is required for conditional generation")

        self.unet.eval()
        if self.text_encoder is not None:
            self.text_encoder.to(self.device)
            self.text_encoder.eval()

        with torch.no_grad():
            captions = self._prepare_text_batch(caption, batch_size, "caption")
            negatives = self._prepare_text_batch(negative_prompt, batch_size, "negative_prompt")

            # --- Prepare text embeddings ---
            # The helpers return concatenated [negative?, unconditional, conditional?]
            # embeddings, matching the chunk() order used during guidance below.
            if isinstance(self.text_encoder, TransformerModel):
                text_embeddings = text_model.get_embeddings(batch_size=batch_size,
                                                            tokenizer=self.text_encoder.tokenizer,
                                                            text_encoder=self.text_encoder,
                                                            captions=captions,
                                                            neg_captions=negatives,
                                                            device=self.device)
            else:  # Case for the pre-trained text encoder
                if self.supports_pretrained_split:  # If we have a split flag incorporated
                    text_embeddings = st_helper.get_embeddings_split(batch_size=batch_size,
                                                                     tokenizer=self.tokenizer,
                                                                     model=self.text_encoder,
                                                                     captions=captions,
                                                                     neg_captions=negatives,
                                                                     device=self.device)
                else:
                    text_embeddings = st_helper.get_embeddings(batch_size=batch_size,
                                                               tokenizer=self.tokenizer,
                                                               model=self.text_encoder,
                                                               captions=captions,
                                                               neg_captions=negatives,
                                                               device=self.device)

            # --- Set up initial latent state ---
            sample = self._prepare_initial_sample(raw_latent_sample, input_scene,
                                                  batch_size, height, width, generator)

            # --- Set up diffusion process ---
            self.scheduler.set_timesteps(num_inference_steps)

            # Denoising loop
            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
            for t in iterator:
                # Handle conditional generation: duplicate the sample so a single
                # UNet call covers all guidance branches at once.
                if captions is not None:
                    if negatives is not None:
                        # Three copies for negative prompt guidance
                        model_input = torch.cat([sample, sample, sample], dim=0)
                    else:
                        # Two copies for standard classifier-free guidance
                        model_input = torch.cat([sample, sample], dim=0)
                else:
                    model_input = sample

                # Predict noise residual
                model_kwargs = {"encoder_hidden_states": text_embeddings}
                noise_pred = self.unet(model_input, t, **model_kwargs).sample

                # Apply guidance
                if captions is not None:
                    if negatives is not None:
                        # Split predictions for negative, unconditional, and text-conditional
                        noise_pred_neg, noise_pred_uncond, noise_pred_text = noise_pred.chunk(3)
                        noise_pred_guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        noise_pred = noise_pred_guided - guidance_scale * (noise_pred_neg - noise_pred_uncond)
                    else:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Compute previous sample: x_{t-1} = scheduler(x_t, noise_pred)
                sample = self.scheduler.step(noise_pred, t, sample, generator=generator).prev_sample

            # Convert to output format
            if output_type == "tensor":
                if self.block_embeddings is not None:
                    sample = get_scene_from_embeddings(sample, self.block_embeddings)
                else:
                    # Apply softmax to get probabilities for each tile type
                    sample = F.softmax(sample, dim=1)
                sample = sample.detach().cpu()
            else:
                raise ValueError(f"Unsupported output type: {output_type}")

        return PipelineOutput(images=sample)

    def print_unet_architecture(self):
        """Prints the architecture of the UNet model."""
        print(self.unet)

    def print_text_encoder_architecture(self):
        """Prints the architecture of the text encoder model, if it exists."""
        if self.text_encoder is not None:
            print(self.text_encoder)
        else:
            print("No text encoder is set.")

    def save_unet_architecture_pdf(self, height, width, filename="unet_architecture", batch_size=1, device=None):
        """
        Have to separately install torchview for this to work.

        Saves a visualization of the UNet architecture as a PDF using torchview.
        Args:
            height: Height of the dummy input.
            width: Width of the dummy input.
            filename: Output PDF filename (without extension).
            batch_size: Batch size for dummy input.
            device: Device to run the dummy input on (defaults to pipeline device).
        """
        from torchview import draw_graph
        import graphviz

        if device is None:
            device = self.device if hasattr(self, 'device') else 'cpu'
        in_channels = self.unet.config.in_channels if hasattr(self.unet, 'config') else 1
        sample_shape = (batch_size, in_channels, height, width)

        dummy_x = torch.randn(size=sample_shape, device=device)
        dummy_t = torch.tensor([0] * batch_size, dtype=torch.long, device=device)

        # Prepare dummy text embedding (match what your UNet expects)
        if hasattr(self.unet, 'config') and hasattr(self.unet.config, 'cross_attention_dim'):
            cross_attention_dim = self.unet.config.cross_attention_dim
        else:
            cross_attention_dim = 128  # fallback
        encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim, device=device)

        self.unet.eval()
        inputs = (dummy_x, dummy_t, encoder_hidden_states)

        graph = draw_graph(
            model=self.unet,
            input_data=inputs,
            expand_nested=False,
            depth=1
        )
        # Tighten layout so the graph fits on fewer pages.
        graph.visual_graph.attr(
            nodesep="0.1",      # decrease space between nodes in the same rank (default ~0.25)
            ranksep="0.2",      # decrease space between ranks (default ~0.5)
            concentrate="true"  # merge edges between nodes in the same rank
        )
        graph.visual_graph.node_attr.update(
            shape="rectangle",
            width="1.5",   # narrow width
            height="0.5"   # taller height to make vertical rectangles
        )

        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup removes intermediate files
        graph.visual_graph.save('unet_architecture.dot')

        # Bug fix: the f-string previously had no placeholder.
        print(f"UNet architecture saved to {filename}.pdf")
|
| 442 |
+
|
models/text_model.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from xml.parsers.expat import model
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
from safetensors.torch import save_file, load_file
|
| 9 |
+
from tokenizer import Tokenizer
|
| 10 |
+
|
| 11 |
+
def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
    """Build the conditioning tensor for classifier-free guidance.

    Concatenates, along the batch dimension and in this order:
    [negative (if given), unconditional (empty captions), conditional (if given)].
    This ordering must match the chunk() order used by the diffusion pipeline.
    """
    max_length = text_encoder.max_seq_length

    def _embed(texts):
        # Tokenize + pad, then run through the encoder.
        ids = encode_token_captions(texts, tokenizer, max_length, device=device)
        return text_encoder.get_embeddings(ids)

    # Always include the unconditional (empty-caption) embeddings.
    embeddings = _embed([""] * batch_size)

    if captions is not None:
        embeddings = torch.cat((embeddings, _embed(captions)), dim=0)

    if neg_captions is not None:
        embeddings = torch.cat((_embed(neg_captions), embeddings), dim=0)

    return embeddings.to(device)
|
| 27 |
+
|
| 28 |
+
def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
    """Tokenize and pad each caption into a (len(captions), max_length) LongTensor."""
    rows = [
        torch.tensor(tokenizer.pad_sequence(tokenizer.encode(text), max_length), dtype=torch.long)
        for text in captions
    ]
    # Stacking 1-D rows is equivalent to concatenating unsqueezed rows.
    return torch.stack(rows, dim=0).to(device)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Transformer model for MLM training
|
| 45 |
+
|
| 46 |
+
# Transformer model for MLM training
class TransformerModel(nn.Module):
    """Transformer encoder used as a caption embedder (trained with MLM).

    Token embeddings plus fixed sinusoidal positional encodings feed a stack
    of TransformerEncoder layers; a final linear head projects back to the
    vocabulary for masked-language-model training.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_seq_length = max_seq_length

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.tokenizer = tokenizer

    def create_positional_encoding(self, max_seq_length, embedding_dim):
        """Return a (1, max_seq_length, embedding_dim) sinusoidal positional encoding.

        Sinusoidal encoding gives each position a unique pattern; the varying
        frequencies create unique values while sin/cos bound them to [-1, 1].
        """
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # A set of divisors that create different frequencies per dimension.
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pe = torch.zeros(max_seq_length, embedding_dim)
        # Even dimensions use sin, odd dimensions use cos.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)

    def get_embeddings(self, x):
        """Return the latent embedding vectors, shape (batch, seq, embedding_dim)."""
        # Ensure positional encoding is on the same device as input.
        pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
        embedded = self.embedding(x) + pe
        return self.transformer(embedded)

    def forward(self, x):
        """Return per-token vocabulary logits, shape (batch, seq, vocab_size)."""
        transformer_out = self.get_embeddings(x)
        # Project to vocabulary size.
        return self.fc(transformer_out)

    def save_pretrained(self, save_directory):
        """Save config (JSON), weights (safetensors), and tokenizer (pickle)."""
        os.makedirs(save_directory, exist_ok=True)

        config = {
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "max_seq_length": self.max_seq_length,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

        # Save model weights
        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))

        # Save tokenizer if present
        if self.tokenizer is not None:
            self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model saved with :meth:`save_pretrained`."""
        with open(os.path.join(load_directory, "config.json")) as f:
            config = json.load(f)

        model = cls(**config)

        # Load weights
        state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
        model.load_state_dict(state_dict)

        # Load tokenizer if available
        tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer()
            tokenizer.load(tokenizer_path)
            model.tokenizer = tokenizer

        return model

    def print_architecture(self, inputs=None):
        # NOTE(review): this method looks like CLI script code pasted into the
        # class — it parses command-line arguments and loads a *different*
        # model from --model_path, ignoring `self` except as the defining
        # scope. Behavior preserved pending confirmation of intent.
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_path", type=str, required=True, help="Path to trained transformer model")
        parser.add_argument("--json", type=str, default="SMB1_LevelsAndCaptions-regular-test.json", help="Path to dataset json file")
        parser.add_argument("--num_samples", type=int, default=10, help="Number of captions to evaluate")
        parser.add_argument("--mask_prob", type=float, default=0.15, help="Probability of masking each token")
        parser.add_argument("--compare_checkpoints", action="store_true", default=False, help="Run comparison across all model checkpoints")
        args = parser.parse_args()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = TransformerModel.from_pretrained(args.model_path).to(device)
        print(f"Loaded model from {args.model_path}")

        import os
        import re
        import json
        import matplotlib.pyplot as plt
        from torchview import draw_graph
        import graphviz

        graph = draw_graph(
            model=model,
            input_data=inputs,
            expand_nested=False,
            depth=1
        )

        # Save plot
        filename = 'mlm_architecture'
        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup removes intermediate files

    def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
        """Save a visualization of the model architecture as a PDF using torchview."""
        try:
            from torchview import draw_graph
        except ImportError:
            raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")
        import torch
        import os
        # Create a dummy input of the correct type for the model.
        captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
        tensor = encode_token_captions(captions, self.tokenizer, self.max_seq_length, device=next(self.parameters()).device)
        input_length = tensor.size(1) if tensor.dim() > 1 else self.max_seq_length

        # NOTE(review): input_length computed above is immediately recomputed
        # from the raw token counts below; kept as-is to preserve behavior.
        num_tokens_list = [len(self.tokenizer.encode(c)) for c in captions]
        input_length = max(num_tokens_list) if num_tokens_list else input_length
        dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=next(self.parameters()).device)

        # Draw the graph and save as PNG
        graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf', ''), directory=".", depth=2)
        png_file = filename.replace('.pdf', '.png')
        # Convert PNG to PDF
        if os.path.exists(png_file):
            try:
                from PIL import Image
                im = Image.open(png_file)
                im.save(filename, "PDF", resolution=100.0)
                # Bug fix: the f-string previously had no placeholder.
                print(f"Saved architecture PDF to {filename}")
                # Optionally, remove the PNG file
                os.remove(png_file)
            except ImportError:
                print(f"PIL not installed. Architecture saved as PNG: {png_file}")
            except Exception as e:
                print(f"Could not convert PNG to PDF: {e}")
        else:
            print(f"Could not find PNG file to convert: {png_file}")
|
util/common_settings.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Shared constants for the generation pipeline.

# Diffusion sampling defaults (steps and classifier-free guidance scale —
# presumably; confirm against the pipeline that consumes them).
NUM_INFERENCE_STEPS = 30
GUIDANCE_SCALE = 7.5

# Super Mario Bros level dimensions, in tiles.
MARIO_HEIGHT = 16
MARIO_WIDTH = 16

# Pixel size of one Mario tile and the number of distinct tile types.
MARIO_TILE_PIXEL_DIM = 16
MARIO_TILE_COUNT = 13

# Lode Runner level dimensions, in tiles.
LR_HEIGHT = 32
LR_WIDTH = 32

# Pixel size of one Lode Runner tile and the number of distinct tile types.
LR_TILE_PIXEL_DIM = 8
LR_TILE_COUNT = 8

# Mega Man level dimensions, in tiles.
MEGAMAN_HEIGHT = 14
MEGAMAN_WIDTH = 16
|
util/naming_conventions.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_name_map = [
|
| 2 |
+
("Mar1and2-conditional-regular", "MLM-regular"),
|
| 3 |
+
("Mar1and2-conditional-absence", "MLM-absence"),
|
| 4 |
+
("Mar1and2-conditional-negative", "MLM-negative"),
|
| 5 |
+
("Mar1and2-conditional-MiniLM-regular", "MiniLM-single-regular"),
|
| 6 |
+
("Mar1and2-conditional-MiniLM-absence", "MiniLM-single-absence"),
|
| 7 |
+
("Mar1and2-conditional-MiniLM-negative", "MiniLM-single-negative"),
|
| 8 |
+
("Mar1and2-conditional-MiniLMsplit-regular", "MiniLM-multiple-regular"),
|
| 9 |
+
("Mar1and2-conditional-MiniLMsplit-absence", "MiniLM-multiple-absence"),
|
| 10 |
+
("Mar1and2-conditional-MiniLMsplit-negative", "MiniLM-multiple-negative"),
|
| 11 |
+
("Mar1and2-conditional-GTE-regular", "GTE-single-regular"),
|
| 12 |
+
("Mar1and2-conditional-GTE-absence", "GTE-single-absence"),
|
| 13 |
+
("Mar1and2-conditional-GTE-negative", "GTE-single-negative"),
|
| 14 |
+
("Mar1and2-conditional-GTEsplit-regular", "GTE-multiple-regular"),
|
| 15 |
+
("Mar1and2-conditional-GTEsplit-absence", "GTE-multiple-absence"),
|
| 16 |
+
("Mar1and2-conditional-GTEsplit-negative", "GTE-multiple-negative"),
|
| 17 |
+
("Mar1and2-fdm-MiniLM-regular", "FDM-MiniLM-regular"),
|
| 18 |
+
("Mar1and2-fdm-MiniLM-absence", "FDM-MiniLM-absence"),
|
| 19 |
+
("Mar1and2-fdm-GTE-regular", "FDM-GTE-regular"),
|
| 20 |
+
("Mar1and2-fdm-GTE-absence", "FDM-GTE-absence"),
|
| 21 |
+
("Mar1and2-wgan", "WGAN"),
|
| 22 |
+
("Mar1and2-unconditional", "Unconditional"),
|
| 23 |
+
("MarioGPT_metrics", "MarioGPT"),
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
def get_model_name_map_and_order():
|
| 27 |
+
mapping = dict(model_name_map)
|
| 28 |
+
order = [v for k, v in model_name_map]
|
| 29 |
+
return mapping, order
|
util/plotter.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Track changes in loss and learning rate during execution
|
| 2 |
+
import argparse
|
| 3 |
+
import matplotlib
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import json
|
| 8 |
+
import tempfile
|
| 9 |
+
import shutil
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args(argv=None):
    """Parse command-line options for the training-progress plotter.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``,
            matching the original behavior; passing a list makes the function
            testable without patching ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Train a text-conditional diffusion model for tile-based level generation")

    # Dataset args
    parser.add_argument("--log_file", type=str, default=None, help="The filepath of the file to get the data from")
    parser.add_argument("--left_key", type=str, default=None, help="The key for the left y-axis")
    parser.add_argument("--right_key", type=str, default=None, help="The key for the right y-axis")
    parser.add_argument("--left_label", type=str, default=None, help="The label for the left y-axis")
    parser.add_argument("--right_label", type=str, default=None, help="The label for the right y-axis")
    parser.add_argument("--output_png", type=str, default="output.png", help="The output png file")
    # Fix: this was declared type=int with a float default (1.0), so argparse
    # would reject fractional values such as "0.5" even though the value is
    # ultimately passed to time.sleep(), where floats are valid.
    parser.add_argument("--update_interval", type=float, default=1.0, help="The update interval in epochs")
    parser.add_argument("--start_point", type=int, default=None, help="The start point for the plot")

    return parser.parse_args(argv)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def main():
    """Entry point: parse CLI options and render the progress plot once."""
    options = parse_args()
    general_update_plot(
        options.log_file,
        options.left_key,
        options.right_key,
        options.left_label,
        options.right_label,
        options.output_png,
        update_interval=options.update_interval,
        startPoint=options.start_point,
    )
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def general_update_plot(log_file, left_key, right_key, left_label, right_label, output_png, update_interval=1.0, startPoint=None):
    """Read a JSON-lines training log and render both metrics on one plot.

    Each line of *log_file* is a JSON object; entries are expected to carry an
    'epoch' key plus the metric keys. The left series plots *left_key* for
    every entry; the right series plots *right_key* only for entries that
    contain it (e.g. learning rate logged less frequently than loss).

    Args:
        log_file: Path to the JSON-lines log. Silently does nothing if absent.
        left_key / right_key: Metric keys to extract from each entry.
        left_label / right_label: Legend labels (left_label is also the y-axis label).
        output_png: Output image; relative bare filenames are written next to
            the log file.
        update_interval: Accepted but unused here — the polling/sleeping is
            done by Plotter.start_plotting.
        startPoint: If given, drop entries with epoch < startPoint and clamp
            the x-axis to start there.
    """
    log_dir = os.path.dirname(log_file)

    # Create figure here and ensure it's closed (see the finally clause).
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    try:
        if os.path.exists(log_file):
            # Skip blank lines so partially flushed logs don't break parsing.
            with open(log_file, 'r') as f:
                data = [json.loads(line) for line in f if line.strip()]

            if not data:
                return

            if startPoint is not None:
                data = [entry for entry in data if entry.get('epoch', 0) >= startPoint]

            # Filtering may have removed everything; nothing to draw then.
            if not data:
                return

            epochs = [entry.get('epoch', 0) for entry in data]
            left = [entry.get(left_key, 0) for entry in data]

            # For right axis (e.g., lr), only include points where right_key exists
            right_points = [(entry.get('epoch', 0), entry.get(right_key))
                            for entry in data if right_key in entry]
            if right_points:
                right_epochs, right_values = zip(*right_points)
            else:
                right_epochs, right_values = [], []

            # Clear axis
            ax.clear()

            # Plot both metrics on the same axis (blue = left, red = right).
            ax.plot(epochs, left, 'b-', label=left_label)
            if right_epochs:
                ax.plot(right_epochs, right_values, 'r-', label=right_label)

            ax.set_xlabel('Epoch')
            ax.set_ylabel(left_label)  # e.g. "Loss" as the y-axis label
            ax.set_title('Training Progress')
            ax.legend(loc='upper left')
            # Limit x-axis to startPoint if provided
            if startPoint is not None:
                ax.set_xlim(left=startPoint)
            fig.tight_layout()

            # Bare filenames land next to the log; absolute or dir-qualified
            # paths are used as-is.
            if os.path.isabs(output_png) or os.path.dirname(output_png):
                output_path = output_png
            else:
                output_path = os.path.join(log_dir, output_png)

            # Atomic-ish write: temp file first, then move into place.
            save_figure_safely(fig, output_path)
    finally:
        plt.close(fig)  # Ensure figure is closed even if an error occurs
|
| 102 |
+
|
| 103 |
+
def save_figure_safely(fig, output_path):
    """Save *fig* to a temporary file first, then move it to the final location.

    Writing to a temp file and moving it into place prevents readers from ever
    seeing a partially written image at *output_path*.

    Args:
        fig: A matplotlib Figure (anything exposing ``savefig(path)``).
        output_path: Final destination path for the PNG; parent directories
            are created as needed.

    Raises:
        Whatever ``fig.savefig`` or the filesystem operations raise; the
        temporary file is cleaned up first.
    """
    output_path = str(Path(output_path))  # Normalize to a plain string path

    # Create temporary file with .png extension; we only need its name.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        # Save to temporary file
        fig.savefig(tmp_path)

        # Create output directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        # Try to move the file to the final destination.
        # If the move fails (e.g. cross-device rename), copy then delete.
        try:
            shutil.move(tmp_path, output_path)
        except OSError:
            shutil.copy2(tmp_path, output_path)
            os.unlink(tmp_path)
    except Exception:
        # Clean up the temporary file, then re-raise. Fix: a bare `raise`
        # preserves the original traceback; the previous `raise e` re-raised
        # from this frame and obscured where the failure actually occurred.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
|
| 130 |
+
|
| 131 |
+
class Plotter:
    """Periodically re-renders a training-progress plot from a JSON-lines log.

    start_plotting() is a blocking redraw/sleep loop, intended to be run in a
    background thread; stop_plotting() ends the loop and writes one final
    frame. The class is also a context manager so the final plot is saved
    when the `with` block exits.
    """

    def __init__(self, log_file, update_interval=1.0, left_key='loss', right_key='lr',
                 left_label='Loss', right_label='Learning Rate', output_png='training_progress.png'):
        # Directory the log lives in; relative output files land here too.
        self.log_dir = os.path.dirname(log_file)
        self.log_file = log_file
        # Used directly by time.sleep() in start_plotting, i.e. seconds.
        self.update_interval = update_interval
        self.running = True
        self.output_png = output_png
        self.left_key = left_key
        self.right_key = right_key
        self.left_label = left_label
        self.right_label = right_label

        # Headless backend: plots are written to files, never shown on screen.
        matplotlib.use('Agg')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Ensure a final plot is written when the `with` block ends.
        self.stop_plotting()

    def __del__(self):
        # Best-effort cleanup if the object is garbage-collected while running.
        self.stop_plotting()

    def update_plot(self):
        # Render a single frame of the progress plot to self.output_png.
        general_update_plot(self.log_file, self.left_key, self.right_key,
                            self.left_label, self.right_label, self.output_png,
                            update_interval=self.update_interval)

    def start_plotting(self):
        # Blocking loop: redraw, then sleep, until stop_plotting() is called.
        print("Starting plotting in background")
        while self.running:
            self.update_plot()
            time.sleep(self.update_interval)

    def stop_plotting(self):
        if hasattr(self, 'running'):  # Check if already stopped
            self.running = False
            # One last render so the final training state is captured.
            self.update_plot()
            print("Plotting stopped")
|
| 171 |
+
|
| 172 |
+
# Allow running this module directly as a one-shot plotting script.
if __name__ == "__main__":
    main()
|
util/sampler.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import subprocess
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from PIL.Image import Image
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from mario_gpt.lm.base import BaseMarioLM
|
| 18 |
+
from mario_gpt.prompter import Prompter
|
| 19 |
+
from mario_gpt.simulator import Simulator
|
| 20 |
+
from mario_gpt.utils import (
|
| 21 |
+
convert_level_to_png,
|
| 22 |
+
load_level,
|
| 23 |
+
save_level,
|
| 24 |
+
trim_level,
|
| 25 |
+
view_level,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def scene_to_ascii(scene, id_to_char, shorten: bool = True) -> List[str]:
    """
    Convert JSON scene data from a list of lists of ints
    to a list of ASCII strings using the id_to_char mapping.

    Fix: the original docstring contradicted itself ("last 15 rows" in the
    summary, "first 15 rows" in the Args) — the code keeps the LAST 15 rows,
    i.e. the bottom of the scene.

    Args:
        scene: List[List[int]] - 2D array of tile IDs
        id_to_char: Dict[int, str] - mapping from tile ID to ASCII character
        shorten: bool - If True, keep only the last 15 rows so A* Mario
            (for SNES graphics) can run without glitching
    Returns:
        List[str]: List of strings, each representing a row in ASCII
    """
    if shorten and len(scene) > 15:
        scene = scene[-15:]  # Keep only the last 15 rows
    return ["".join(id_to_char[num] for num in row) for row in scene]
|
| 44 |
+
|
| 45 |
+
@dataclass
class SampleOutput:
    """Container for one sampled level: decoded string rows, rendered image,
    the generating prompt, and the raw token tensors they came from."""

    level: Optional[List[str]]
    prompt: Optional[str] = None
    img: Optional[Image] = None
    sample_predictions_str: Optional[List[str]] = None
    sample_predictions_img: Optional[Image] = None
    level_tensor: Optional[torch.Tensor] = None
    sample_predictions_tensor: Optional[torch.Tensor] = None
    # Uses MarioEval graphics for rendering levels when True
    use_snes_graphics: bool = False

    @classmethod
    def create(
        cls,
        level_tensor: torch.Tensor,
        sample_predictions_tensor: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> SampleOutput:
        """Build a SampleOutput for a single (unbatched) level tensor.

        Rendering is best-effort: if decoding or drawing fails, the
        corresponding fields are left as None instead of raising.
        """
        # batch = 1
        level = None
        img = None

        try:
            level = view_level(level_tensor, tokenizer)
            img = convert_level_to_png(level)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for full level! Got error {e}"
            )
            level = None
            img = None
        try:
            sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
            sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for sampled predictions! Got error {e}"
            )
            sample_predictions_str = None
            sample_predictions_img = None

        prompt = None
        if prompter is not None:
            prompt = prompter(level_tensor)[0]

        return SampleOutput(
            level,
            prompt,
            img,
            sample_predictions_str,
            sample_predictions_img,
            level_tensor,
            sample_predictions_tensor,
        )

    @classmethod
    def from_level_predictions(
        cls,
        level: torch.Tensor,
        sample_predictions: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> Union[SampleOutput, List[SampleOutput]]:
        """Decode a possibly batched level tensor.

        Returns a single SampleOutput for a 1-D tensor, or a list of them
        (one per batch row) for a 2-D tensor.
        """
        level_tensor = trim_level(level).squeeze().detach().cpu()
        sample_predictions_tensor = (
            trim_level(sample_predictions).squeeze().detach().cpu()
        )

        if len(level_tensor.shape) == 1:
            return SampleOutput.create(
                level_tensor, sample_predictions_tensor, tokenizer, prompter
            )

        out = []
        for _level_tensor, _sample_predictions_tensor in zip(
            level_tensor, sample_predictions_tensor
        ):
            sample_output = SampleOutput.create(
                _level_tensor, _sample_predictions_tensor, tokenizer, prompter
            )
            out.append(sample_output)
        return out

    def save(self, filename: str) -> str:
        """Write the level rows to *filename* and return the path.

        Fix: the original was annotated ``-> str`` but returned None;
        save_level already returns the filename, so propagate it.
        """
        return save_level(self.level, filename)

    @classmethod
    def load(cls, filename: str) -> SampleOutput:
        """Load a level from a text file; only the `level` field is populated."""
        level = load_level(filename)
        return SampleOutput(level=level)

    def play(self, game="mario", level_idx=None, dataset_path=None):
        """
        Play the level using the specified game engine.
        game: "mario" (default) or "loderunner"
        """
        if game == "loderunner":
            import tempfile, json
            # Convert self.level (list of strings) to Lode Runner JSON format
            scene = [[c for c in row] for row in self.level]
            lr_json = [{
                "scene": scene,
                "caption": ""
            }]
            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
                json.dump(lr_json, tmp)
                tmp_path = tmp.name
            import sys, os
            from LodeRunner.loderunner import main
            # dataset_path, when supplied, overrides the temp file we just wrote.
            tmp_path = tmp_path if dataset_path is None else dataset_path
            print(f"Playing Lode Runner level interactively -- {tmp_path}!")
            main.play_lr_level(tmp_path, level_index=level_idx if level_idx is not None else 1)
        else:
            # SNES vs NES graphics are provided by different evaluation jars.
            if self.use_snes_graphics:
                simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
            else:
                simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
            simulator.interactive()

    def run_astar(self, render=True):
        """Run the A* agent on this level; returns the simulator's console output."""
        if self.use_snes_graphics:
            simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
        else:
            simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
        return simulator.astar(render)
|
| 173 |
+
|
| 174 |
+
class CustomSimulator:
    """
    The classic Mario simulator used by MarioGPT is generally
    better, but it doesn't return any information about
    Mario's performance. The main point of this simulator
    is that information about the performance of the agent
    is printed to the console (though I still need a way
    to caption and return that information)
    """

    def __init__(self, level, jar_path="MarioEval.jar"):
        """
        Args:
            level: List of ASCII row strings (top row first).
            jar_path: Path to the Java evaluation jar to launch.
        """
        # For some reason, my older A* agent crashes on Mario levels with
        # 16 rows or more, so keep only the bottom 15 rows.
        # Fix: the original popped rows off the front of the caller's list,
        # destructively mutating the argument; slicing leaves it untouched.
        if len(level) > 15:
            level = level[-15:]

        self.level = level
        self.jar_path = jar_path

    def interactive(self):
        """Launch the Java simulator so a human can play the level."""
        t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        save_level(self.level, t.name)
        print(f"Playing level interactively -- {t.name}!")
        _ = subprocess.run(
            ["java", "-jar", self.jar_path, "human", t.name, "human"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        t.close()
        os.unlink(t.name)

    def astar(self, render: bool = True):
        """Run the A* agent on the level.

        Args:
            render: Show the simulator window when True; headless otherwise.

        Returns:
            The agent's combined stdout+stderr as a single string.
        """
        t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        save_level(self.level, t.name)
        print(f"Running Astar agent on level! -- {t.name}")
        render_str = "human" if render else "norender"
        result = subprocess.run(
            ["java", "-jar", self.jar_path, "astar", t.name, render_str],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        t.close()
        os.unlink(t.name)
        # Combine stdout and stderr, decode to string, and return
        output = result.stdout.decode("utf-8") + result.stderr.decode("utf-8")
        return output
|
| 220 |
+
|
| 221 |
+
def save_level(level: List[str], filename: str):
    """Write the level's rows to *filename*, newline-separated, and return
    the path that was written."""
    with open(filename, "w") as handle:
        handle.write("\n".join(level))
    return filename
|
| 226 |
+
|
| 227 |
+
class GPTSampler:
    """Autoregressive token sampler for a MarioGPT-style language model.

    Generates level tokens one at a time, conditioning each step on a sliding
    window of previously generated tokens plus prompt embeddings passed as
    encoder hidden states.
    """

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 700,
        use_tqdm: bool = False,
        use_argmax: bool = False,
    ):
        # mario_lm: wrapper exposing .lm (the transformer), .tokenizer,
        # .prompter, .generate_seed and .device.
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        # Maximum token window fed back into the model each step.
        self.context_len = context_len
        self.use_tqdm = use_tqdm
        # When True, greedy decoding; otherwise top-k/temperature sampling.
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )

    @property
    def device(self) -> torch.device:
        # Delegate to the wrapped model's device.
        return self.mario_lm.device

    def step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample one next token per batch row.

        Returns (next_tokens, encoder_hidden_states) — the hidden states are
        passed through unchanged so the caller can thread them along.
        """
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            # An unbatched 2-D logits tensor is normalized to (1, 1, vocab).
            if len(logits.shape) == 2:
                logits = logits.view(1, 1, -1)
            # Only the final position's logits matter for the next token.
            next_token_logits = logits[:, -1, :]

            if self.use_argmax:
                next_tokens = next_token_logits.argmax(-1)
            else:
                # Top-k filter + temperature, then sample from the softmax.
                next_token_scores = self.logits_processor(input_ids, next_token_logits)
                next_token_scores = self.logits_warper(input_ids, next_token_scores)
                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        encoder_hidden_states: torch.Tensor = None,
        return_tensor: bool = False,
    ):
        """Generate `num_steps` tokens, optionally continuing from a seed.

        seed: None (fresh seed per prompt), a token tensor, or a previous
            SampleOutput to continue from.
        prompts: Text prompts; embedded via the prompter when
            encoder_hidden_states is not supplied directly.
        return_tensor: Also return the raw token tensor when True.
        """
        self.mario_lm.eval()
        # Reserve headroom below the hard window (28 tokens = two 14-tall columns).
        context_len = self.context_len - 28
        with torch.no_grad():
            if seed is None:
                seed = self.mario_lm.generate_seed(1, batch_size=len(prompts)).to(
                    self.device
                )
                out_tensor = seed.to(self.device)
            elif isinstance(seed, SampleOutput):
                out_tensor = seed.level_tensor.to(self.device).squeeze()
            else:
                out_tensor = seed.to(self.device).squeeze()
            if len(out_tensor.shape) < 2:
                # if we pass in a single seed vector, then we repeat for each prompt
                # Otherwise, we treat inputs as separate seed-prompt pairs
                out_tensor = out_tensor.view(1, -1).repeat(len(prompts), 1)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter.output_hidden(prompt)
                            for prompt in prompts
                        ]
                    )
                else:
                    # No prompts given: draw random prompt embeddings instead.
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter(sample_prompt=True)[1]
                            for _ in range(seed.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(
                out_tensor.shape[0], 1, -1
            )
            if not self.use_tqdm:
                bar = np.arange(num_steps)
            else:
                bar = tqdm(np.arange(num_steps))
            with torch.no_grad():
                for i in bar:
                    # `* 1` forces a copy so slicing below can't alias out_tensor.
                    inp = out_tensor * 1
                    if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
                        # Keep the window aligned to whole level columns.
                        diff = inp.shape[-1] % 14  # height of mario level
                        ctx = context_len + diff
                        inp = inp[:, -ctx:] * 1
                    next_tokens, encoder_hidden_states = self.step(
                        inp,
                        encoder_hidden_states=encoder_hidden_states,
                    )
                    out_tensor = torch.cat(
                        [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
                    )
                    if self.use_tqdm:
                        bar.set_description(
                            f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
                        )
            if self.use_tqdm:
                bar.close()
        # Decode the full sequence plus just the newly generated suffix.
        sample_out = SampleOutput.from_level_predictions(
            out_tensor,
            out_tensor[:, -num_steps:],
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        # NOTE(review): restores train mode unconditionally, even if the model
        # was in eval mode before sample() was called.
        self.mario_lm.train()
        if return_tensor:
            return sample_out, out_tensor
        return sample_out

    def __call__(self, *args, **kwargs):
        # Convenience: sampler(...) is equivalent to sampler.sample(...).
        return self.sample(*args, **kwargs)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class BertSampler:
|
| 368 |
+
def __init__(
|
| 369 |
+
self,
|
| 370 |
+
mario_lm: BaseMarioLM,
|
| 371 |
+
temperature: float = 2.0,
|
| 372 |
+
top_k: int = 16,
|
| 373 |
+
context_len: int = 448,
|
| 374 |
+
mask_proportion: float = 0.16,
|
| 375 |
+
):
|
| 376 |
+
self.mario_lm = mario_lm
|
| 377 |
+
self.temperature = temperature
|
| 378 |
+
self.top_k = top_k
|
| 379 |
+
self.logits_processor = LogitsProcessorList()
|
| 380 |
+
self.logits_warper = LogitsProcessorList(
|
| 381 |
+
[
|
| 382 |
+
TopKLogitsWarper(top_k), # number of characters
|
| 383 |
+
TemperatureLogitsWarper(temperature),
|
| 384 |
+
]
|
| 385 |
+
)
|
| 386 |
+
self.context_len = context_len
|
| 387 |
+
self.mask_proportion = mask_proportion
|
| 388 |
+
self.mask_portion = int(self.context_len * self.mask_proportion)
|
| 389 |
+
self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14
|
| 390 |
+
|
| 391 |
+
@property
|
| 392 |
+
def device(self) -> torch.device:
|
| 393 |
+
return self.mario_lm.device
|
| 394 |
+
|
| 395 |
+
def get_context(self, input_ids, mask_indices):
|
| 396 |
+
start_idx = mask_indices[0]
|
| 397 |
+
end_idx = mask_indices[-1]
|
| 398 |
+
|
| 399 |
+
if input_ids.shape[-1] <= self.context_len:
|
| 400 |
+
clipped = input_ids.shape[-1] % 14
|
| 401 |
+
input_ids = input_ids[:clipped]
|
| 402 |
+
|
| 403 |
+
portion = (self.context_len - self.mask_portion) / 2
|
| 404 |
+
|
| 405 |
+
remainder = 0
|
| 406 |
+
left = start_idx - portion
|
| 407 |
+
if left < 0:
|
| 408 |
+
remainder = -1 * left
|
| 409 |
+
|
| 410 |
+
right = end_idx + portion + remainder
|
| 411 |
+
|
| 412 |
+
return input_ids[left:right]
|
| 413 |
+
|
| 414 |
+
def sample(
    self,
    seed: Union[torch.Tensor, SampleOutput],
    mask: torch.Tensor,
    return_tensor: bool = False,
):
    """Fill the masked positions of ``seed`` with tokens predicted by the LM.

    Args:
        seed: level token tensor of shape (batch, seq) — or a ``SampleOutput``
            whose ``level_tensor`` is used — containing the tokens to edit.
        mask: boolean/int tensor, same shape as the seed, nonzero where tokens
            should be regenerated.
        return_tensor: when True, also return the raw predicted-token tensor.

    Returns:
        A ``SampleOutput`` built from the edited tokens, optionally paired
        with the predicted-token tensor.
    """
    self.mario_lm.eval()
    # Rows of (batch_index, position) for every masked entry.
    mask_indices = mask.nonzero()
    input_ids = seed
    if isinstance(seed, SampleOutput):
        input_ids = seed.level_tensor.to(self.device).squeeze()

    # Clip each row to a context window around its masked span.
    input_id_list = []
    for i in range(input_ids.shape[0]):
        row = input_ids[i]
        row_mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
        input_id_list.append(self.get_context(row, row_mask_index))
    # FIX: the original called torch.stack(input_ids, ...) — a Tensor, which
    # torch.stack rejects — leaving `input_id_list` dead. Stack the clipped
    # rows instead. NOTE(review): this assumes get_context returns equal-length
    # rows and that mask positions still index into the clipped window —
    # TODO confirm for seeds longer than context_len.
    input_ids = torch.stack(input_id_list, dim=0).to(self.device)

    # FIX: use self.device — `seed.device` fails when seed is a SampleOutput.
    attention_mask = torch.ones_like(input_ids).to(self.device)

    if len(input_ids.shape) < 2:
        # if we pass in a single seed vector, then we repeat for each prompt
        # Otherwise, we treat inputs as separate seed-prompt pairs
        input_ids = input_ids.view(1, -1)

    out = self.mario_lm.lm(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=None,
    )
    logits = out.logits.detach()
    if len(logits.shape) == 2:
        logits = logits.view(1, 1, -1)

    # FIX: `use_argmax` is never set in __init__; default to sampling so the
    # configured top_k/temperature warpers are honoured.
    if getattr(self, "use_argmax", False):
        tokens = logits.argmax(-1)
    else:
        # FIX: the original passed the (not yet defined) `tokens` here.
        tokens_scores = self.logits_processor(input_ids, logits)
        tokens_scores = self.logits_warper(input_ids, tokens_scores)
        probs = torch.nn.functional.softmax(tokens_scores, dim=-1)
        # torch.multinomial only accepts 1-D/2-D inputs: flatten the
        # (batch, seq, vocab) distribution, sample one token per position,
        # then restore the (batch, seq) shape.
        flat_probs = probs.view(-1, probs.shape[-1])
        tokens = torch.multinomial(flat_probs, num_samples=1).view(
            probs.shape[0], probs.shape[1]
        )

    # Clone so the in-place writes below cannot alias the input tensor's
    # storage (detach alone shares memory).
    out = input_ids.detach().clone()

    # Overwrite only the masked positions with the predicted tokens.
    for i in range(input_ids.shape[0]):
        row_mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
        out[i, row_mask_index] = tokens[i, row_mask_index].detach()

    sample_out = SampleOutput.from_level_predictions(
        out,
        tokens,
        self.mario_lm.tokenizer,
        self.mario_lm.prompter,
    )
    self.mario_lm.train()
    if return_tensor:
        return sample_out, tokens
    return sample_out
|