schrum2 committed on
Commit
e7ac06f
·
verified ·
1 Parent(s): d161b39

Deleting directories, moving files into root

Browse files
captions/caption_match.py DELETED
@@ -1,247 +0,0 @@
1
- from create_ascii_captions import assign_caption
2
-
3
# Quantity order for scoring partial matches (index distance drives the score)
QUANTITY_TERMS = ["one", "two", "a few", "several", "many"]

# Topics to compare; longer/more specific keywords come first so they claim
# their phrase before a shorter keyword (e.g. "broken pipe" before "pipe").
TOPIC_KEYWORDS = [
    #"giant gap", # I think all gaps are subsumed by the floor topic
    "floor", "ceiling",
    "broken pipe", "upside down pipe", "pipe",
    "coin line", "coin",
    "platform", "tower", #"wall",
    "broken cannon", "cannon",
    "ascending staircase", "descending staircase",
    "rectangular",
    "irregular",
    "question block", "loose block",
    "enem" # catch "enemy"/"enemies"
]

# Need list because the order matters
KEYWORD_TO_NEGATED_PLURAL = [
    (" broken pipe.", ""), # If not the first phrase
    ("broken pipe. ", ""), # If the first phrase (after removing all others)
    (" broken cannon.", ""), # If not the first phrase
    ("broken cannon. ", ""), # If the first phrase (after removing all others)
    ("pipe", "pipes"),
    ("cannon", "cannons"),
    ("platform", "platforms"),
    ("tower", "towers"),
    ("staircase", "staircases"),
    ("enem", "enemies"),
    ("rectangular", "rectangular block clusters"),
    ("irregular", "irregular block clusters"),
    ("coin line", "coin lines"),
    ("coin.", "coins."), # Need period to avoid matching "coin line"
    ("question block", "question blocks"),
    ("loose block", "loose blocks")
]

BROKEN_TOPICS = 2 # Number of topics that are considered "broken" (e.g., "broken pipe", "broken cannon")

# Plural normalization map (irregulars)
PLURAL_EXCEPTIONS = {
    "enemies": "enemy",
}
47
-
48
def normalize_plural(phrase):
    """Return *phrase* with plural nouns reduced to singular form.

    Irregular plurals are mapped via PLURAL_EXCEPTIONS; regular plurals
    lose a single trailing "s" (words ending in "ss" are left alone).
    """
    # Irregular plurals first (e.g. "enemies" -> "enemy").
    for plural, singular in PLURAL_EXCEPTIONS.items():
        phrase = phrase.replace(plural, singular)

    # Strip a trailing "s" from regular plurals, avoiding "ss" words like "boss".
    singularized = [
        word[:-1] if word.endswith('s') and not word.endswith('ss') else word
        for word in phrase.split()
    ]
    return ' '.join(singularized)
63
-
64
def extract_phrases(caption, debug=False):
    """Map every topic in TOPIC_KEYWORDS to the first caption phrase mentioning it.

    Phrases are the period-separated parts of *caption*. A phrase starting
    with "no " (or a topic with no matching phrase) maps to None. Each phrase
    can be claimed by at most one topic, so longer keywords listed earlier in
    TOPIC_KEYWORDS take precedence over shorter ones they contain.
    """
    phrases = [part.strip() for part in caption.split('.') if part.strip()]
    topic_to_phrase = {}
    claimed = set()  # phrases already assigned to an earlier topic

    for topic in TOPIC_KEYWORDS:
        candidates = [p for p in phrases if topic in p and p not in claimed]

        if not candidates:
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': no phrase found")
            continue

        phrase = candidates[0]
        if phrase.lower().startswith("no "):
            # "no ..." phrases count as absence of the topic
            topic_to_phrase[topic] = None
            if debug:
                print(f"[Extract] Topic '{topic}': detected 'no ...', treating as None")
        else:
            topic_to_phrase[topic] = phrase
            claimed.add(phrase)  # mark this phrase as matched
            if debug:
                print(f"[Extract] Topic '{topic}': found phrase '{phrase}'")

    return topic_to_phrase
95
-
96
def quantity_score(phrase1, phrase2, debug=False):
    """Score agreement between the quantity words of two phrases.

    Returns 1.0 for matching quantities, a linearly decreasing value for
    quantities further apart in QUANTITY_TERMS, and 0.1 when either phrase
    has no recognizable quantity word.
    """
    def first_quantity(phrase):
        # First QUANTITY_TERMS entry contained in the phrase, else None.
        return next((term for term in QUANTITY_TERMS if term in phrase), None)

    qty1 = first_quantity(phrase1)
    qty2 = first_quantity(phrase2)

    if debug:
        print(f"[Quantity] Comparing quantities: '{qty1}' vs. '{qty2}'")

    if qty1 and qty2:
        idx1 = QUANTITY_TERMS.index(qty1)
        idx2 = QUANTITY_TERMS.index(qty2)
        diff = abs(idx1 - idx2)
        max_diff = len(QUANTITY_TERMS) - 1
        score = 1.0 - (diff / max_diff)
        if debug:
            print(f"[Quantity] Quantity indices: {idx1} vs. {idx2}, diff: {diff}, score: {score:.2f}")
        return score

    if debug:
        print("[Quantity] At least one quantity missing, assigning partial score 0.1")
    return 0.1
121
-
122
def compare_captions(correct_caption, generated_caption, debug=False, return_matches=False):
    """Score how well *generated_caption* matches *correct_caption*, topic by topic.

    Each topic in TOPIC_KEYWORDS contributes:
      +1.0 when both captions agree (both absent, or identical after plural
           normalization), a quantity-based partial score when both phrases
           contain quantity words, +0.1 for mere topic overlap, and
      -1.0 when exactly one caption mentions the topic.
    The total is divided by the number of topics, so the result lies in
    [-1.0, 1.0].

    When return_matches is True, also returns the lists of exactly matched,
    partially matched, and excess generated phrases.
    """
    correct_phrases = extract_phrases(correct_caption, debug=debug)
    generated_phrases = extract_phrases(generated_caption, debug=debug)

    total_score = 0.0
    num_topics = len(TOPIC_KEYWORDS)

    exact_matches = []
    partial_matches = []
    excess_phrases = []

    if debug:
        print("\n--- Starting Topic Comparison ---\n")

    for topic in TOPIC_KEYWORDS:
        correct = correct_phrases[topic]
        generated = generated_phrases[topic]

        if debug:
            print(f"[Topic: {topic}] Correct: {correct} | Generated: {generated}")

        if correct is None and generated is None:
            # Topic absent from both captions: full agreement.
            total_score += 1.0
            if debug:
                print(f"[Topic: {topic}] Both None — full score: 1.0\n")
        elif correct is None or generated is None:
            # Topic present on only one side: penalize.
            total_score += -1.0
            if generated is not None:  # Considered an excess phrase
                excess_phrases.append(generated)
            if debug:
                print(f"[Topic: {topic}] One is None — penalty: -1.0\n")
        else:
            # Normalize pluralization before comparison
            norm_correct = normalize_plural(correct)
            norm_generated = normalize_plural(generated)

            if debug:
                print(f"[Topic: {topic}] Normalized: Correct: '{norm_correct}' | Generated: '{norm_generated}'")

            if norm_correct == norm_generated:
                total_score += 1.0
                exact_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Exact match — score: 1.0\n")
            elif any(term in norm_correct for term in QUANTITY_TERMS) and any(term in norm_generated for term in QUANTITY_TERMS):
                # Both phrases quantify the topic: score by quantity closeness.
                qty_score = quantity_score(norm_correct, norm_generated, debug=debug)
                total_score += qty_score
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Quantity-based partial score: {qty_score:.2f}\n")
            else:
                # Topic mentioned by both but phrased differently: small credit.
                total_score += 0.1
                partial_matches.append(generated)
                if debug:
                    print(f"[Topic: {topic}] Partial match (topic overlap) — score: 0.1\n")

        if debug:
            print(f"[Topic: {topic}] Current total score: {total_score:.4f}\n")

    if debug:
        print("total_score before normalization:", total_score)
        print(f"Number of topics: {num_topics}")

    final_score = total_score / num_topics
    if debug:
        print(f"--- Final score: {final_score:.4f} ---\n")

    if return_matches:
        return final_score, exact_matches, partial_matches, excess_phrases

    return final_score
193
-
194
def process_scene_segments(scene, segment_width, prompt, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence, verbose=False):
    """
    Partition a scene into vertical segments, caption each, and score them
    against *prompt*.

    Args:
        scene (list): The scene to process, represented as a 2D list.
        segment_width (int): The width of each segment.
        prompt (str): The prompt to compare captions against.
        id_to_char (dict): Mapping from tile IDs to characters.
        char_to_id (dict): Mapping from characters to tile IDs.
        tile_descriptors (dict): Descriptions of individual tile types.
        describe_locations (bool): Whether to include location descriptions in captions.
        describe_absence (bool): Whether to indicate absence of items in captions.
        verbose (bool): If True, print captions and scores for each segment.

    Returns:
        tuple: (average comparison score, captions per segment, scores per segment).
    """
    # Slice every row at the same column offsets to form vertical segments.
    segments = [
        [row[start:start + segment_width] for row in scene]
        for start in range(0, len(scene[0]), segment_width)
    ]

    segment_captions = []
    segment_scores = []
    for idx, segment in enumerate(segments):
        caption = assign_caption(segment, id_to_char, char_to_id, tile_descriptors, describe_locations, describe_absence)
        score = compare_captions(prompt, caption)
        segment_captions.append(caption)
        segment_scores.append(score)

        if verbose:
            print(f"Segment {idx + 1} caption: {caption}")
            print(f"Segment {idx + 1} comparison score: {score}")

    # Average over segments; 0 when the scene produced no segments.
    average_score = sum(segment_scores) / len(segment_scores) if segment_scores else 0

    if verbose:
        print(f"Average comparison score across all segments: {average_score}")

    return average_score, segment_captions, segment_scores
238
-
239
if __name__ == '__main__':

    # Quick manual check: compare a reference caption against a generated one
    # and print the resulting similarity score.
    ref = "floor with one gap. two enemies. one platform. one tower."
    gen = "giant gap with one chunk of floor. two enemies. one platform. one tower."

    score = compare_captions(ref, gen, debug=True)
    print(f"Should be: {ref}")
    print(f" but was: {gen}")
    print(f"Score: {score}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
captions/evaluate_caption_order_tolerance.py DELETED
@@ -1,288 +0,0 @@
1
- import argparse
2
- import itertools
3
- import os
4
- import random
5
- from collections import defaultdict
6
- import sys, os
7
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
- import util.common_settings as common_settings # adjust import if needed
9
- from level_dataset import LevelDataset, visualize_samples, colors, mario_tiles # adjust import if needed
10
- from torch.utils.data import DataLoader
11
- from evaluate_caption_adherence import calculate_caption_score_and_samples # adjust import if needed
12
- import matplotlib.pyplot as plt
13
- import matplotlib
14
- import json
15
- from tqdm import tqdm
16
-
17
- import numpy as np
18
- import torch
19
- from tqdm import tqdm
20
-
21
- from captions.util import extract_tileset
22
- from models.pipeline_loader import get_pipeline
23
-
24
-
25
def parse_args():
    """Parse command-line options for the caption-order-tolerance evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate caption order tolerance for a diffusion model.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained diffusion model")
    parser.add_argument("--caption", type=str, required=False, default=None, help="Caption to evaluate, phrases separated by periods")
    parser.add_argument("--tileset", type=str, help="Path to the tileset JSON file")
    #parser.add_argument("--json", type=str, default="datasets\\Test_for_caption_order_tolerance.json", help="Path to dataset json file")
    #parser.add_argument("--json", type=str, default="datasets\\SMB1_LevelsAndCaptions-regular-test.json", help="Path to dataset json file")
    parser.add_argument("--json", type=str, default="datasets\\Mar1and2_LevelsAndCaptions-regular.json", help="Path to dataset json file")
    #parser.add_argument("--trials", type=int, default=3, help="Number of times to evaluate each caption permutation")
    parser.add_argument("--inference_steps", type=int, default=common_settings.NUM_INFERENCE_STEPS)
    parser.add_argument("--guidance_scale", type=float, default=common_settings.GUIDANCE_SCALE)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--game", type=str, choices=["Mario", "LR"], default="Mario", help="Game to evaluate (Mario or Lode Runner)")
    parser.add_argument("--describe_absence", action="store_true", default=False, help="Indicate when there are no occurrences of an item or structure")
    parser.add_argument("--save_as_json", action="store_true", help="Save generated levels as JSON")
    parser.add_argument("--output_dir", type=str, default="visualizations", help="Output directory if not comparing checkpoints (subdir of model directory)")
    parser.add_argument("--max_permutations", type=int, default=5, help="Maximum amount of permutations that can be made per caption")
    return parser.parse_args()
43
-
44
-
45
def setup_environment(seed):
    """Seed every relevant RNG (random, numpy, torch) and return the device name."""
    use_cuda = torch.cuda.is_available()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)
    return "cuda" if use_cuda else "cpu"
53
-
54
def load_captions_from_json(json_path):
    """Read a dataset JSON file and collect every entry's "caption" value.

    Entries without a "caption" key are silently skipped.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [entry["caption"] for entry in data if "caption" in entry]
60
-
61
def creation_of_parameters(caption, max_permutations):
    """Build everything needed to evaluate one caption (or a list of captions).

    Splits each caption into period-separated phrases, samples up to
    *max_permutations* phrase orderings per caption, and wraps the permuted
    captions in a LevelDataset/DataLoader.

    Args:
        caption (str | list[str]): Caption(s) whose phrase order is permuted.
        max_permutations (int): Cap on the number of permutations per caption.

    Returns:
        tuple: (pipe, device, id_to_char, char_to_id, tile_descriptors,
                num_tiles, dataloader, perm_captions, caption_data)
    """
    args = parse_args()
    device = setup_environment(args.seed)

    if args.game == "Mario":
        num_tiles = common_settings.MARIO_TILE_COUNT
        # Raw strings: keep the Windows backslashes literal (avoids invalid-escape warnings).
        tileset = r'..\TheVGLC\Super Mario Bros\smb.json'
    elif args.game == "LR":
        num_tiles = common_settings.LR_TILE_COUNT
        tileset = r'..\TheVGLC\Lode Runner\Loderunner.json'
    else:
        raise ValueError(f"Unknown game: {args.game}")

    # Load pipeline
    pipe = get_pipeline(args.model_path).to(device)

    # Load tile metadata
    tile_chars, id_to_char, char_to_id, tile_descriptors = extract_tileset(tileset)

    def permuted_phrase_orders(cap):
        # Split one caption into phrases and sample up to max_permutations orderings.
        phrases = [p.strip() for p in cap.split('.') if p.strip()]
        perms = list(itertools.permutations(phrases))
        if len(perms) > max_permutations:
            perms = random.sample(perms, max_permutations)
        return perms

    perm_captions = []
    if isinstance(caption, list):
        # caption is a list of caption strings
        for cap in caption:
            perm_captions.extend('.'.join(perm) + '.' for perm in permuted_phrase_orders(cap))
    elif isinstance(caption, str):
        perm_captions = ['.'.join(perm) + '.' for perm in permuted_phrase_orders(caption)]

    # Create a list of dicts as expected by LevelDataset
    caption_data = [{"scene": None, "caption": cap} for cap in perm_captions]

    # Initialize dataset
    dataset = LevelDataset(
        data_as_list=caption_data,
        shuffle=False,
        mode="text",
        augment=False,
        # BUG FIX: previously hard-coded to common_settings.MARIO_TILE_COUNT,
        # which was wrong when --game LR selected the Lode Runner tile count.
        num_tiles=num_tiles,
        negative_captions=False,
        block_embeddings=None
    )

    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=min(16, len(perm_captions)),
        shuffle=False,
        num_workers=4,
        drop_last=False,
        persistent_workers=True
    )

    return pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles, dataloader, perm_captions, caption_data
131
-
132
def statistics_of_captions(captions, dataloader, compare_all_scores, pipe=None, device=None, id_to_char=None, char_to_id=None, tile_descriptors=None, num_tiles=None):
    """
    Calculate statistics of the captions.
    Returns average, standard deviation, minimum, maximum, and median of caption scores.

    NOTE(review): only *captions* and *compare_all_scores* are actually used;
    the remaining parameters are accepted but ignored. Also, when *captions*
    is empty this returns None, while callers unpack six values — confirm
    callers never hit that path.
    """
    # Re-parses the command line just to report which JSON file was used.
    args = parse_args()
    if not captions:
        print("No captions found in the provided JSON file.")
        return
    print(f"\nLoaded {len(captions)} captions from {args.json}")

    # Aggregate statistics over the per-permutation comparison scores.
    avg_score = np.mean(compare_all_scores)
    std_dev_score = np.std(compare_all_scores)
    min_score = np.min(compare_all_scores)
    max_score = np.max(compare_all_scores)
    median_score = np.median(compare_all_scores)

    print("\n-----Scores for each caption permutation-----")
    for i, score in enumerate(compare_all_scores):
        print(f"Scores for caption {i + 1}:", score)

    print("\n-----Statistics of captions-----")
    print(f"Average score: {avg_score:.4f}")
    print(f"Standard deviation: {std_dev_score:.4f}")
    print(f"Minimum score: {min_score:.4f}")
    print(f"Maximum score: {max_score:.4f}")
    print(f"Median score: {median_score:.4f}")

    return compare_all_scores, avg_score, std_dev_score, min_score, max_score, median_score
162
-
163
def main():
    """Evaluate how tolerant the model is to caption phrase order.

    For every caption (from --caption or the --json dataset), generate levels
    for several phrase-order permutations, score caption adherence, write
    per-caption statistics to a JSONL file, and optionally visualize samples.
    """
    args = parse_args()
    if args.caption is None or args.caption == "":
        caption = load_captions_from_json(args.json)
    else:
        caption = args.caption
        #caption = ("many pipes. many coins. , many enemies. many blocks. , many platforms. many question blocks.").split(',')

    all_scores = []
    all_avg_scores = []
    all_std_dev_scores = []
    all_min_scores = []
    all_max_scores = []
    all_median_scores = []
    # Commas inside a caption string separate independent captions.
    all_captions = [item.strip() for s in caption for item in s.split(",")]

    one_caption = []
    count = 0

    output_jsonl_path = os.path.join(args.output_dir, "evaluation_caption_order_results.jsonl")
    with open(output_jsonl_path, "w") as f:
        for cap in all_captions:
            one_caption = cap

            # Initialize dataset (rebuilds pipeline/dataloader for each caption)
            pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles, dataloader, perm_caption, caption_data = creation_of_parameters(one_caption, args.max_permutations)
            if not pipe:
                print("Failed to create pipeline.")
                return

            avg_score, all_samples, all_prompts, compare_all_scores = calculate_caption_score_and_samples(device, pipe, dataloader, args.inference_steps, args.guidance_scale, args.seed, id_to_char, char_to_id, tile_descriptors, args.describe_absence, output=True, height=common_settings.MARIO_HEIGHT, width=common_settings.MARIO_WIDTH)
            scores, avg_score, std_dev_score, min_score, max_score, median_score = statistics_of_captions(perm_caption, dataloader, compare_all_scores, pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles)

            if args.save_as_json:
                result_entry = {
                    "Caption": one_caption,
                    "Average score for all permutations": avg_score,
                    "Standard deviation": std_dev_score,
                    "Minimum score": min_score,
                    "Maximum score": max_score,
                    "Median score": median_score
                    #"samples": all_samples[i].tolist() if hasattr(all_samples, "__getitem__") else None,
                    #"prompt": all_prompts[i] if i < len(all_prompts) else "N/A"
                }
                f.write(json.dumps(result_entry) + "\n")

            all_avg_scores.append(avg_score)

            #scores, avg_score, std_dev_score, min_score, max_score, median_score = statistics_of_captions(perm_caption, dataloader, compare_all_scores, pipe, device, id_to_char, char_to_id, tile_descriptors, num_tiles)
            # NOTE(review): enumerate() yields (index, score) tuples, so all_scores
            # collects tuples rather than raw scores — confirm this is intended.
            for score in enumerate(scores):
                all_scores.append(score)
            all_std_dev_scores.append(std_dev_score)
            all_min_scores.append(min_score)
            all_max_scores.append(max_score)
            all_median_scores.append(median_score)
            if (count % 10) == 0:
                f.flush() # Ensure each result is written immediately
                os.fsync(f.fileno()) # Ensure file is flushed to disk
            count = count + 1

        print(f"\nAverage score across all captions: {avg_score:.4f}")

        visualizations_dir = os.path.join(os.path.dirname(__file__), "visualizations")
        # NOTE(review): `or ""` is always falsy, so this condition is equivalent to
        # `args.caption is not None` — likely meant `and args.caption != ""`.
        if args.caption is not None or "":
            caption_folder = args.caption.replace(" ", "_").replace(".", "_")
            output_directory = os.path.join(visualizations_dir, caption_folder)

            visualize_samples(
                all_samples,
                output_dir=output_directory,
                prompts=all_prompts[0] if all_prompts else "No prompts available"
            )
            print(f"\nVisualizations saved to: {output_directory}")

        # These reflect only the LAST caption processed in the loop above.
        print("\nAll samples shape:", all_samples.shape)
        print("\nAll prompts:", all_prompts)

    # Aggregate statistics across all captions.
    all_avg_score = np.mean(all_avg_scores)
    all_std_dev_score = np.std(all_std_dev_scores)
    all_min_score = np.min(all_min_scores)
    all_max_score = np.max(all_max_scores)
    all_median_score = np.median(all_median_scores)

    if args.save_as_json:
        # NOTE(review): re-opening the same path with "w" truncates the per-caption
        # entries written above — confirm whether append mode was intended.
        output_jsonl_path = os.path.join(args.output_dir, "evaluation_caption_order_results.jsonl")
        with open(output_jsonl_path, "w") as f:
            if isinstance(caption, list) or (args.caption is None or args.caption == ""):
                # Multiple captions (permuted)
                for i, score in enumerate(all_avg_scores):
                    result_entry = {
                        "Caption": caption[i] if i < len(caption) else "N/A",
                        "Average score for all permutations": score,
                        #"samples": all_samples[i].tolist() if hasattr(all_samples, "__getitem__") else None,
                        #"prompt": all_prompts[i] if i < len(all_prompts) else "N/A"
                    }
                    f.write(json.dumps(result_entry) + "\n")
            else:
                # Single caption
                result_entry = {
                    "caption": caption,
                    "avg_score": avg_score,
                    "samples": all_samples.tolist(),
                    "prompts": all_prompts
                }
                f.write(json.dumps(result_entry) + "\n")

            results = {

                "Scores of all captions": {
                    "Scores": all_scores,
                    "Number of captions": len(all_scores),
                    "Average of all permutations": all_avg_score,
                    "Standard deviation of all permutations": all_std_dev_score,
                    "Min score of all permutations": all_min_score,
                    "Max score of all permutations": all_max_score,
                    "Median score of all permutations": all_median_score
                },
            }
            json.dump(results, f, indent=4)

    print(f"Results saved to {output_jsonl_path}")
287
# Script entry point.
if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
captions/util.py DELETED
@@ -1,216 +0,0 @@
1
- import json
2
- import sys
3
- import os
4
- from collections import Counter
5
-
6
# This file contains utility functions for analyzing and describing levels in both Lode Runner and Super Mario Bros.

# Could define these via the command line, but for now they are hardcoded
coarse_locations = True  # use coarse location wording rather than exact coordinates
coarse_counts = True  # describe counts with words ("a few") instead of numbers
pluralize = True  # pluralize nouns when the count is not 1
give_staircase_lengths = False  # include staircase lengths in captions
13
-
14
def describe_size(count):
    """Classify a structure size: up to 4 tiles is "small", anything larger is "big"."""
    return "small" if count <= 4 else "big"
17
-
18
def describe_quantity(count):
    """Map a count to a coarse quantity word ("no", "one", ..., "many")."""
    if count == 0:
        return "no"
    if count == 1:
        return "one"
    if count == 2:
        return "two"
    if count < 5:
        return "a few"
    if count < 10:
        return "several"
    return "many"
25
-
26
def get_tile_descriptors(tileset):
    """Creates a mapping from tile character to its set of descriptors."""
    descriptors = {}
    for char, attrs in tileset["tiles"].items():
        descriptors[char] = set(attrs)
    # Fake tiles. Should these contain anything? Note that code elsewhere expects everything to be passable or solid
    descriptors["!"] = {"passable"}
    descriptors["*"] = {"passable"}
    return descriptors
33
-
34
def analyze_floor(scene, id_to_char, tile_descriptors, describe_absence):
    """Analyzes the last row of the 32x32 scene and generates a floor description.

    Returns one of: "full floor", "no floor"/"" (all passable),
    "floor with N gap(s)" (mostly solid), or
    "giant gap with N chunk(s) of floor" (mostly passable).
    """
    WIDTH = len(scene[0])
    last_row = scene[-1]  # The FLOOR row of the scene
    # Solid or diggable tiles both count as floor.
    solid_count = sum(
        1 for tile in last_row
        if tile in id_to_char and (
            "solid" in tile_descriptors.get(id_to_char[tile], []) or
            "diggable" in tile_descriptors.get(id_to_char[tile], [])
        )
    )
    passable_count = sum(
        1 for tile in last_row if "passable" in tile_descriptors.get(id_to_char[tile], [])
    )

    if solid_count == WIDTH:
        return "full floor"
    elif passable_count == WIDTH:
        if describe_absence:
            return "no floor"
        else:
            return ""
    elif solid_count > passable_count:
        # Count contiguous groups of passable tiles
        gaps = 0
        in_gap = False
        for tile in last_row:
            # Enemies are also a gap since they immediately fall into the gap
            if "passable" in tile_descriptors.get(id_to_char[tile], []) or "enemy" in tile_descriptors.get(id_to_char[tile], []):
                if not in_gap:
                    gaps += 1
                    in_gap = True
            elif "solid" in tile_descriptors.get(id_to_char[tile], []):
                in_gap = False
            else:
                print("error")
                print(tile)
                print(id_to_char[tile])
                print(tile_descriptors)
                print(tile_descriptors.get(id_to_char[tile], []))
                raise ValueError("Every tile should be passable, solid, or enemy")
        return f"floor with {describe_quantity(gaps) if coarse_counts else gaps} gap" + ("s" if pluralize and gaps != 1 else "")
    else:
        # Count contiguous groups of solid tiles
        chunks = 0
        in_chunk = False
        for tile in last_row:
            if "solid" in tile_descriptors.get(id_to_char[tile], []):
                if not in_chunk:
                    chunks += 1
                    in_chunk = True
            elif "passable" in tile_descriptors.get(id_to_char[tile], []) or "enemy" in tile_descriptors.get(id_to_char[tile], []):
                in_chunk = False
            else:
                print("error")
                print(tile)
                print(tile_descriptors)
                # NOTE(review): looks up descriptors by raw tile id here, unlike the
                # id_to_char[tile] lookup used in the branch above — confirm intended.
                print(tile_descriptors.get(tile, []))
                raise ValueError("Every tile should be either passable or solid")
        return f"giant gap with {describe_quantity(chunks) if coarse_counts else chunks} chunk"+("s" if pluralize and chunks != 1 else "")+" of floor"
94
-
95
def count_in_scene(scene, tiles, exclude=frozenset()):
    """Count standalone tiles in *scene* that belong to *tiles*.

    Args:
        scene: 2D list of tile values.
        tiles: Collection of tile values to count.
        exclude: Set of (row, col) positions to skip.
                 (frozenset default replaces the old mutable `set()` default.)

    Returns:
        int: Number of matching tiles outside the excluded positions.
    """
    count = 0
    for r, row in enumerate(scene):
        for c, t in enumerate(row):
            if (r, c) not in exclude and t in tiles:
                count += 1
    return count
106
-
107
def count_caption_phrase(scene, tiles, name, names, offset=0, describe_absence=False, exclude=frozenset()):
    """Build a caption fragment describing how many of *tiles* appear in *scene*.

    Args:
        scene: 2D list of tile values, passed to count_in_scene.
        tiles: Collection of tile values to count.
        name: Singular noun for the caption.
        names: Plural noun for the caption.
        offset: Added to the raw count before captioning (offset modifies count used in caption).
        describe_absence: When True, emit " no <names>." for a zero count.
        exclude: (row, col) positions to skip when counting.
                 (frozenset default replaces the old mutable `set()` default.)

    Returns:
        str: A phrase like " two pipes." or "" when nothing should be said.
    """
    count = offset + count_in_scene(scene, tiles, exclude)
    if count > 0:
        quantity = describe_quantity(count) if coarse_counts else count
        return f" {quantity} " + (names if pluralize and count > 1 else name) + "."
    elif describe_absence:
        return f" no {names}."
    else:
        return ""
117
-
118
def in_column(scene, x, tile):
    """Return True if *tile* appears anywhere in column *x* of *scene*."""
    return any(row[x] == tile for row in scene)
124
-
125
def analyze_ceiling(scene, id_to_char, tile_descriptors, describe_absence, ceiling_row = 1):
    """
    Analyzes ceiling row (0-based index) to detect a ceiling.
    Returns a caption phrase or an empty string if no ceiling is detected.
    A ceiling requires more than half the row to be solid.
    """
    WIDTH = len(scene[0])

    row = scene[ceiling_row]
    # Tiles whose descriptors include "solid" count toward the ceiling.
    solid_count = sum(1 for tile in row if "solid" in tile_descriptors.get(id_to_char[tile], []))

    if solid_count == WIDTH:
        return " full ceiling."
    elif solid_count > WIDTH//2:
        # Count contiguous gaps of passable tiles
        gaps = 0
        in_gap = False
        for tile in row:
            # Enemies are also a gap since they immediately fall into the gap, but they are marked as "moving" and not "passable"
            if "passable" in tile_descriptors.get(id_to_char[tile], []) or "moving" in tile_descriptors.get(id_to_char[tile], []):
                if not in_gap:
                    gaps += 1
                    in_gap = True
            else:
                in_gap = False
        result = f" ceiling with {describe_quantity(gaps) if coarse_counts else gaps} gap" + ("s" if pluralize and gaps != 1 else "") + "."

        # Adding the "moving" check should make this code unnecessary
        #if result == ' ceiling with no gaps.':
        #    print("This should not happen: ceiling with no gaps")
        #    print("ceiling_row:", scene[ceiling_row])
        #    result = " full ceiling."

        return result
    elif describe_absence:
        return " no ceiling."
    else:
        return "" # Not enough solid tiles for a ceiling
162
-
163
def extract_tileset(tileset_path):
    """Load a tileset JSON file and derive the tile-id mappings.

    Returns:
        tuple: (sorted tile characters, id->char dict, char->id dict,
                char->descriptor-set dict from get_tile_descriptors).
    """
    with open(tileset_path, "r") as f:
        tileset = json.load(f)

    # Sort so that tile ids are stable across runs.
    tile_chars = sorted(tileset['tiles'].keys())
    id_to_char = dict(enumerate(tile_chars))
    char_to_id = {char: idx for idx, char in enumerate(tile_chars)}
    tile_descriptors = get_tile_descriptors(tileset)

    return tile_chars, id_to_char, char_to_id, tile_descriptors
182
-
183
def flood_fill(scene, visited, start_row, start_col, id_to_char, tile_descriptors, excluded, pipes=False, target_descriptor=None):
    """Iteratively flood-fill a connected structure starting at (start_row, start_col).

    Positions already in *visited* or listed in *excluded* are skipped; matched
    positions are added to *visited* and returned as a list of (row, col) pairs.
    When *target_descriptor* is given, only tiles carrying that descriptor are
    filled; otherwise solid tiles are filled — pipe tiles only when pipes=True,
    non-pipe solids when pipes=False.
    """
    stack = [(start_row, start_col)]
    structure = []

    while stack:
        row, col = stack.pop()
        if (row, col) in visited or (row, col) in excluded:
            continue
        tile = scene[row][col]
        descriptors = tile_descriptors.get(id_to_char[tile], [])
        # Use target_descriptor if provided, otherwise default to old solid/pipe logic
        if target_descriptor is not None:
            if target_descriptor not in descriptors:
                continue
        else:
            if "solid" not in descriptors or (not pipes and "pipe" in descriptors) or (pipes and "pipe" not in descriptors):
                continue

        visited.add((row, col))
        structure.append((row, col))

        # Check neighbors (up, down, left, right)
        for d_row, d_col in [(-1,0), (1,0), (0,-1), (0,1)]:
            # Weird special case for adjacent pipes: don't cross a pipe's outer
            # edge, so side-by-side pipes are kept as separate structures.
            if (id_to_char[tile] == '>' or id_to_char[tile] == ']') and d_col == 1: # if on the right edge of a pipe
                continue # Don't go right if on the right edge of a pipe
            if (id_to_char[tile] == '<' or id_to_char[tile] == '[') and d_col == -1: # if on the left edge of a pipe
                continue # Don't go left if on the left edge of a pipe

            n_row, n_col = row + d_row, col + d_col
            if 0 <= n_row < len(scene) and 0 <= n_col < len(scene[0]):
                stack.append((n_row, n_col))

    return structure